Skip to content

Commit 253834e

Browse files
authored
[FEAT] add driver/executor pod in Spark (#3016)
Signed-off-by: machichima <nary12321@gmail.com>
1 parent 08c01aa commit 253834e

File tree

5 files changed

+337
-20
lines changed

5 files changed

+337
-20
lines changed

flytekit/models/task.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1053,6 +1053,7 @@ def to_flyte_idl(self) -> _core_task.K8sPod:
10531053
metadata=self._metadata.to_flyte_idl() if self.metadata else None,
10541054
pod_spec=_json_format.Parse(_json.dumps(self.pod_spec), _struct.Struct()) if self.pod_spec else None,
10551055
data_config=self.data_config.to_flyte_idl() if self.data_config else None,
1056+
primary_container_name=self.primary_container_name,
10561057
)
10571058

10581059
@classmethod
@@ -1081,6 +1082,7 @@ def from_pod_template(cls, pod_template: "PodTemplate") -> "K8sPod":
10811082
return cls(
10821083
metadata=K8sObjectMetadata(labels=pod_template.labels, annotations=pod_template.annotations),
10831084
pod_spec=ApiClient().sanitize_for_serialization(pod_template.pod_spec),
1085+
primary_container_name=pod_template.primary_container_name,
10841086
)
10851087

10861088

plugins/flytekit-spark/flytekitplugins/spark/models.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from flytekit.exceptions import user as _user_exceptions
99
from flytekit.models import common as _common
10+
from flytekit.models.task import K8sPod
1011

1112

1213
class SparkType(enum.Enum):
@@ -27,6 +28,8 @@ def __init__(
2728
executor_path: str,
2829
databricks_conf: Optional[Dict[str, Dict[str, Dict]]] = None,
2930
databricks_instance: Optional[str] = None,
31+
driver_pod: Optional[K8sPod] = None,
32+
executor_pod: Optional[K8sPod] = None,
3033
):
3134
"""
3235
This defines a SparkJob target. It will execute the appropriate SparkJob.
@@ -47,6 +50,8 @@ def __init__(
4750
databricks_conf = {}
4851
self._databricks_conf = databricks_conf
4952
self._databricks_instance = databricks_instance
53+
self._driver_pod = driver_pod
54+
self._executor_pod = executor_pod
5055

5156
def with_overrides(
5257
self,
@@ -71,6 +76,8 @@ def with_overrides(
7176
hadoop_conf=new_hadoop_conf,
7277
databricks_conf=new_databricks_conf,
7378
databricks_instance=self.databricks_instance,
79+
driver_pod=self.driver_pod,
80+
executor_pod=self.executor_pod,
7481
executor_path=self.executor_path,
7582
)
7683

@@ -139,6 +146,22 @@ def databricks_instance(self) -> str:
139146
"""
140147
return self._databricks_instance
141148

149+
@property
150+
def driver_pod(self) -> K8sPod:
151+
"""
152+
Additional pod specs for driver pod.
153+
:rtype: K8sPod
154+
"""
155+
return self._driver_pod
156+
157+
@property
158+
def executor_pod(self) -> K8sPod:
159+
"""
160+
Additional pod specs for the worker node pods.
161+
:rtype: K8sPod
162+
"""
163+
return self._executor_pod
164+
142165
def to_flyte_idl(self):
143166
"""
144167
:rtype: flyteidl.plugins.spark_pb2.SparkJob
@@ -167,6 +190,8 @@ def to_flyte_idl(self):
167190
hadoopConf=self.hadoop_conf,
168191
databricksConf=databricks_conf,
169192
databricksInstance=self.databricks_instance,
193+
driverPod=self.driver_pod.to_flyte_idl() if self.driver_pod else None,
194+
executorPod=self.executor_pod.to_flyte_idl() if self.executor_pod else None,
170195
)
171196

172197
@classmethod
@@ -193,4 +218,6 @@ def from_flyte_idl(cls, pb2_object):
193218
executor_path=pb2_object.executorPath,
194219
databricks_conf=json_format.MessageToDict(pb2_object.databricksConf),
195220
databricks_instance=pb2_object.databricksInstance,
221+
driver_pod=pb2_object.driverPod,
222+
executor_pod=pb2_object.executorPod,
196223
)

plugins/flytekit-spark/flytekitplugins/spark/task.py

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,11 @@
1010
from flytekit import FlyteContextManager, PythonFunctionTask, lazy_module, logger
1111
from flytekit.configuration import DefaultImages, SerializationSettings
1212
from flytekit.core.context_manager import ExecutionParameters
13+
from flytekit.core.pod_template import PRIMARY_CONTAINER_DEFAULT_NAME, PodTemplate
1314
from flytekit.extend import ExecutionState, TaskPlugins
1415
from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin
1516
from flytekit.image_spec import DefaultImageBuilder, ImageSpec
17+
from flytekit.models.task import K8sPod
1618

1719
from .models import SparkJob, SparkType
1820

@@ -26,17 +28,21 @@ class Spark(object):
2628
Use this to configure a SparkContext for a your task. Task's marked with this will automatically execute
2729
natively onto K8s as a distributed execution of spark
2830
29-
Args:
30-
spark_conf: Dictionary of spark config. The variables should match what spark expects
31-
hadoop_conf: Dictionary of hadoop conf. The variables should match a typical hadoop configuration for spark
32-
executor_path: Python binary executable to use for PySpark in driver and executor.
33-
applications_path: MainFile is the path to a bundled JAR, Python, or R file of the application to execute.
31+
Attributes:
32+
spark_conf (Optional[Dict[str, str]]): Spark configuration dictionary.
33+
hadoop_conf (Optional[Dict[str, str]]): Hadoop configuration dictionary.
34+
executor_path (Optional[str]): Path to the Python binary for PySpark execution.
35+
applications_path (Optional[str]): Path to the main application file.
36+
driver_pod (Optional[PodTemplate]): The pod template for the Spark driver pod.
37+
executor_pod (Optional[PodTemplate]): The pod template for the Spark executor pod.
3438
"""
3539

3640
spark_conf: Optional[Dict[str, str]] = None
3741
hadoop_conf: Optional[Dict[str, str]] = None
3842
executor_path: Optional[str] = None
3943
applications_path: Optional[str] = None
44+
driver_pod: Optional[PodTemplate] = None
45+
executor_pod: Optional[PodTemplate] = None
4046

4147
def __post_init__(self):
4248
if self.spark_conf is None:
@@ -172,6 +178,8 @@ def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
172178
executor_path=self._default_executor_path or settings.python_interpreter,
173179
main_class="",
174180
spark_type=SparkType.PYTHON,
181+
driver_pod=self.to_k8s_pod(self.task_config.driver_pod),
182+
executor_pod=self.to_k8s_pod(self.task_config.executor_pod),
175183
)
176184
if isinstance(self.task_config, (Databricks, DatabricksV2)):
177185
cfg = cast(DatabricksV2, self.task_config)
@@ -180,6 +188,27 @@ def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
180188

181189
return MessageToDict(job.to_flyte_idl())
182190

191+
def to_k8s_pod(self, pod_template: Optional[PodTemplate] = None) -> Optional[K8sPod]:
192+
"""
193+
Convert the podTemplate to K8sPod
194+
"""
195+
if pod_template is None:
196+
return None
197+
198+
task_primary_container_name = (
199+
self.pod_template.primary_container_name if self.pod_template else PRIMARY_CONTAINER_DEFAULT_NAME
200+
)
201+
202+
if pod_template.primary_container_name != task_primary_container_name:
203+
logger.warning(
204+
"Primary container name ('%s') set in spark differs from the one in @task ('%s'). "
205+
"The primary container name in @task will be overridden.",
206+
pod_template.primary_container_name,
207+
task_primary_container_name,
208+
)
209+
210+
return K8sPod.from_pod_template(pod_template)
211+
183212
def pre_execute(self, user_params: ExecutionParameters) -> ExecutionParameters:
184213
import pyspark as _pyspark
185214

0 commit comments

Comments
 (0)