1010from flytekit import FlyteContextManager , PythonFunctionTask , lazy_module , logger
1111from flytekit .configuration import DefaultImages , SerializationSettings
1212from flytekit .core .context_manager import ExecutionParameters
13+ from flytekit .core .pod_template import PRIMARY_CONTAINER_DEFAULT_NAME , PodTemplate
1314from flytekit .extend import ExecutionState , TaskPlugins
1415from flytekit .extend .backend .base_agent import AsyncAgentExecutorMixin
1516from flytekit .image_spec import DefaultImageBuilder , ImageSpec
17+ from flytekit .models .task import K8sPod
1618
1719from .models import SparkJob , SparkType
1820
@@ -26,17 +28,21 @@ class Spark(object):
2628 Use this to configure a SparkContext for your task. Tasks marked with this will automatically execute
2729 natively on K8s as a distributed execution of Spark.
2830
29- Args:
30- spark_conf: Dictionary of spark config. The variables should match what spark expects
31- hadoop_conf: Dictionary of hadoop conf. The variables should match a typical hadoop configuration for spark
32- executor_path: Python binary executable to use for PySpark in driver and executor.
33- applications_path: MainFile is the path to a bundled JAR, Python, or R file of the application to execute.
31+ Attributes:
32+ spark_conf (Optional[Dict[str, str]]): Spark configuration dictionary.
33+ hadoop_conf (Optional[Dict[str, str]]): Hadoop configuration dictionary.
34+ executor_path (Optional[str]): Path to the Python binary for PySpark execution.
35+ applications_path (Optional[str]): Path to the main application file.
36+ driver_pod (Optional[PodTemplate]): The pod template for the Spark driver pod.
37+ executor_pod (Optional[PodTemplate]): The pod template for the Spark executor pod.
3438 """
3539
3640 spark_conf : Optional [Dict [str , str ]] = None
3741 hadoop_conf : Optional [Dict [str , str ]] = None
3842 executor_path : Optional [str ] = None
3943 applications_path : Optional [str ] = None
44+ driver_pod : Optional [PodTemplate ] = None
45+ executor_pod : Optional [PodTemplate ] = None
4046
4147 def __post_init__ (self ):
4248 if self .spark_conf is None :
@@ -172,6 +178,8 @@ def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
172178 executor_path = self ._default_executor_path or settings .python_interpreter ,
173179 main_class = "" ,
174180 spark_type = SparkType .PYTHON ,
181+ driver_pod = self .to_k8s_pod (self .task_config .driver_pod ),
182+ executor_pod = self .to_k8s_pod (self .task_config .executor_pod ),
175183 )
176184 if isinstance (self .task_config , (Databricks , DatabricksV2 )):
177185 cfg = cast (DatabricksV2 , self .task_config )
@@ -180,6 +188,27 @@ def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
180188
181189 return MessageToDict (job .to_flyte_idl ())
182190
def to_k8s_pod(self, pod_template: Optional[PodTemplate] = None) -> Optional[K8sPod]:
    """
    Build the ``K8sPod`` model for a Spark driver or executor pod template.

    Args:
        pod_template: Pod template configured on the Spark task config, or
            ``None`` when no template was supplied.

    Returns:
        ``None`` when ``pod_template`` is ``None``; otherwise the result of
        ``K8sPod.from_pod_template(pod_template)``.
    """
    if pod_template is not None:
        # Primary container name the task itself uses: taken from the task-level
        # pod template when one is set, otherwise the default name.
        if self.pod_template:
            task_container_name = self.pod_template.primary_container_name
        else:
            task_container_name = PRIMARY_CONTAINER_DEFAULT_NAME

        spark_container_name = pod_template.primary_container_name
        if spark_container_name != task_container_name:
            # A mismatch is only reported; the conversion below still proceeds.
            logger.warning(
                "Primary container name ('%s') set in spark differs from the one in @task ('%s'). "
                "The primary container name in @task will be overridden.",
                spark_container_name,
                task_container_name,
            )

        return K8sPod.from_pod_template(pod_template)

    return None
183212 def pre_execute (self , user_params : ExecutionParameters ) -> ExecutionParameters :
184213 import pyspark as _pyspark
185214
0 commit comments