
Commit 5068030

fix: spark operator label (apache#45353)
* fix: spark operator label
* update spark operator
* update spark kube
* make ci happy
* update test
* format
* format
1 parent 7e9b9ed · commit 5068030

2 files changed: +41 −14 lines changed

providers/src/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py

Lines changed: 12 additions & 14 deletions
@@ -19,7 +19,7 @@

 from functools import cached_property
 from pathlib import Path
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, cast

 from kubernetes.client import CoreV1Api, CustomObjectsApi, models as k8s

@@ -177,12 +177,7 @@ def create_job_name(self):
         return self._set_name(updated_name)

     @staticmethod
-    def _get_pod_identifying_label_string(labels) -> str:
-        filtered_labels = {label_id: label for label_id, label in labels.items() if label_id != "try_number"}
-        return ",".join([label_id + "=" + label for label_id, label in sorted(filtered_labels.items())])
-
-    @staticmethod
-    def create_labels_for_pod(context: dict | None = None, include_try_number: bool = True) -> dict:
+    def _get_ti_pod_labels(context: Context | None = None, include_try_number: bool = True) -> dict[str, str]:
         """
         Generate labels for the pod to track the pod in case of Operator crash.

@@ -193,8 +188,9 @@ def create_labels_for_pod(context: dict | None = None, include_try_number: bool
         if not context:
             return {}

-        ti = context["ti"]
-        run_id = context["run_id"]
+        context_dict = cast(dict, context)
+        ti = context_dict["ti"]
+        run_id = context_dict["run_id"]

         labels = {
             "dag_id": ti.dag_id,
@@ -213,8 +209,8 @@ def create_labels_for_pod(context: dict | None = None, include_try_number: bool

         # In the case of sub dags this is just useful
         # TODO: Remove this when the minimum version of Airflow is bumped to 3.0
-        if getattr(context["dag"], "is_subdag", False):
-            labels["parent_dag_id"] = context["dag"].parent_dag.dag_id
+        if getattr(context_dict["dag"], "is_subdag", False):
+            labels["parent_dag_id"] = context_dict["dag"].parent_dag.dag_id
         # Ensure that label is valid for Kube,
         # and if not truncate/remove invalid chars and replace with short hash.
         for label_id, label in labels.items():
@@ -235,9 +231,11 @@ def template_body(self):
         """Templated body for CustomObjectLauncher."""
         return self.manage_template_specs()

-    def find_spark_job(self, context):
-        labels = self.create_labels_for_pod(context, include_try_number=False)
-        label_selector = self._get_pod_identifying_label_string(labels) + ",spark-role=driver"
+    def find_spark_job(self, context, exclude_checked: bool = True):
+        label_selector = (
+            self._build_find_pod_label_selector(context, exclude_checked=exclude_checked)
+            + ",spark-role=driver"
+        )
         pod_list = self.client.list_namespaced_pod(self.namespace, label_selector=label_selector).items

         pod = None
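
For context on what this hunk changes: both the removed `_get_pod_identifying_label_string` path and the new `_build_find_pod_label_selector` path boil down to listing pods by a Kubernetes label selector and restricting to the Spark driver. Below is a minimal standalone sketch of that lookup, assuming a reachable cluster; the label keys mirror the ones built above, but the label values and namespace are placeholders, not taken from the commit.

from kubernetes import client, config

# Standalone sketch of the label-selector lookup that find_spark_job performs above.
# Label values and namespace are illustrative placeholders.
config.load_kube_config()  # or config.load_incluster_config() when running in-cluster
v1 = client.CoreV1Api()

labels = {"dag_id": "my_dag", "task_id": "my_task", "run_id": "manual__2025-01-01T000000"}
# Join key=value pairs and restrict to the Spark driver pod, as the operator does.
label_selector = ",".join(f"{k}={v}" for k, v in sorted(labels.items())) + ",spark-role=driver"

pods = v1.list_namespaced_pod("default", label_selector=label_selector).items
driver_pod = pods[0] if pods else None
print(driver_pod.metadata.name if driver_pod else "no driver pod found")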

providers/tests/cncf/kubernetes/operators/test_spark_kubernetes.py

Lines changed: 29 additions & 0 deletions
@@ -701,6 +701,35 @@ def test_get_logs_from_driver(
             follow_logs=True,
         )

+    def test_find_custom_pod_labels(
+        self,
+        mock_create_namespaced_crd,
+        mock_get_namespaced_custom_object_status,
+        mock_cleanup,
+        mock_create_job_name,
+        mock_get_kube_client,
+        mock_create_pod,
+        mock_await_pod_start,
+        mock_await_pod_completion,
+        mock_fetch_requested_container_logs,
+        data_file,
+    ):
+        task_name = "test_find_custom_pod_labels"
+        job_spec = yaml.safe_load(data_file("spark/application_template.yaml").read_text())
+
+        mock_create_job_name.return_value = task_name
+        op = SparkKubernetesOperator(
+            template_spec=job_spec,
+            kubernetes_conn_id="kubernetes_default_kube_config",
+            task_id=task_name,
+            get_logs=True,
+        )
+        context = create_context(op)
+        op.execute(context)
+        label_selector = op._build_find_pod_label_selector(context) + ",spark-role=driver"
+        op.find_spark_job(context)
+        mock_get_kube_client.list_namespaced_pod.assert_called_with("default", label_selector=label_selector)
+

 @pytest.mark.db_test
 def test_template_body_templating(create_task_instance_of_operator, session):
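
For orientation only, a hedged sketch of how an operator configured like the one in this test might appear in a DAG. The DAG id, schedule, template file path, and connection id are assumptions for illustration, not part of the change; only the constructor arguments mirror the test above.

from datetime import datetime

import yaml

from airflow import DAG
from airflow.providers.cncf.kubernetes.operators.spark_kubernetes import SparkKubernetesOperator

# Load a SparkApplication spec from a YAML template (path is a placeholder).
with open("spark/application_template.yaml") as f:
    job_spec = yaml.safe_load(f)

with DAG(dag_id="spark_pi_example", start_date=datetime(2025, 1, 1), schedule=None):
    # Same constructor arguments the test above passes; the connection id is an assumption.
    submit_spark_app = SparkKubernetesOperator(
        task_id="submit_spark_app",
        template_spec=job_spec,
        kubernetes_conn_id="kubernetes_default",
        get_logs=True,
    )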
