1919
2020from functools import cached_property
2121from pathlib import Path
22- from typing import TYPE_CHECKING , Any
22+ from typing import TYPE_CHECKING , Any , cast
2323
2424from kubernetes .client import CoreV1Api , CustomObjectsApi , models as k8s
2525
@@ -177,12 +177,7 @@ def create_job_name(self):
177177 return self ._set_name (updated_name )
178178
179179 @staticmethod
180- def _get_pod_identifying_label_string (labels ) -> str :
181- filtered_labels = {label_id : label for label_id , label in labels .items () if label_id != "try_number" }
182- return "," .join ([label_id + "=" + label for label_id , label in sorted (filtered_labels .items ())])
183-
184- @staticmethod
185- def create_labels_for_pod (context : dict | None = None , include_try_number : bool = True ) -> dict :
180+ def _get_ti_pod_labels (context : Context | None = None , include_try_number : bool = True ) -> dict [str , str ]:
186181 """
187182 Generate labels for the pod to track the pod in case of Operator crash.
188183
@@ -193,8 +188,9 @@ def create_labels_for_pod(context: dict | None = None, include_try_number: bool
193188 if not context :
194189 return {}
195190
196- ti = context ["ti" ]
197- run_id = context ["run_id" ]
191+ context_dict = cast (dict , context )
192+ ti = context_dict ["ti" ]
193+ run_id = context_dict ["run_id" ]
198194
199195 labels = {
200196 "dag_id" : ti .dag_id ,
@@ -213,8 +209,8 @@ def create_labels_for_pod(context: dict | None = None, include_try_number: bool
213209
214210 # In the case of sub dags this is just useful
215211 # TODO: Remove this when the minimum version of Airflow is bumped to 3.0
216- if getattr (context ["dag" ], "is_subdag" , False ):
217- labels ["parent_dag_id" ] = context ["dag" ].parent_dag .dag_id
212+ if getattr (context_dict ["dag" ], "is_subdag" , False ):
213+ labels ["parent_dag_id" ] = context_dict ["dag" ].parent_dag .dag_id
218214 # Ensure that label is valid for Kube,
219215 # and if not truncate/remove invalid chars and replace with short hash.
220216 for label_id , label in labels .items ():
@@ -235,9 +231,11 @@ def template_body(self):
235231 """Templated body for CustomObjectLauncher."""
236232 return self .manage_template_specs ()
237233
238- def find_spark_job (self , context ):
239- labels = self .create_labels_for_pod (context , include_try_number = False )
240- label_selector = self ._get_pod_identifying_label_string (labels ) + ",spark-role=driver"
234+ def find_spark_job (self , context , exclude_checked : bool = True ):
235+ label_selector = (
236+ self ._build_find_pod_label_selector (context , exclude_checked = exclude_checked )
237+ + ",spark-role=driver"
238+ )
241239 pod_list = self .client .list_namespaced_pod (self .namespace , label_selector = label_selector ).items
242240
243241 pod = None
0 commit comments