Skip to content

Commit fc5410c

Browse files
authored
Fix sensor skipping in Airflow 3.x branching operators (#53455)
In Airflow 3.x, sensors inherit from `airflow.sdk.BaseOperator` instead of `airflow.models.BaseOperator`. The module-level `_ensure_tasks` helper in `skipmixin.py` imported the SDK `BaseOperator` only when `AIRFLOW_V_3_1_PLUS` was true, so on Airflow 3.0.x it still checked nodes against the models `BaseOperator`; sensors were therefore filtered out and not properly skipped by branching operators like `BranchSQLOperator`. This change gates the import on `AIRFLOW_V_3_0_PLUS` so the SDK `BaseOperator` is used on every Airflow 3.x release, and adds tests verifying that sensors are included in branching and short-circuit skip operations. Fixes #52219
1 parent df5c949 commit fc5410c

File tree

3 files changed

+131
-2
lines changed

3 files changed

+131
-2
lines changed

providers/standard/src/airflow/providers/standard/utils/skipmixin.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from typing import TYPE_CHECKING
2323

2424
from airflow.exceptions import AirflowException
25-
from airflow.providers.standard.version_compat import AIRFLOW_V_3_1_PLUS
25+
from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS
2626
from airflow.utils.log.logging_mixin import LoggingMixin
2727

2828
if TYPE_CHECKING:
@@ -40,7 +40,7 @@
4040

4141

4242
def _ensure_tasks(nodes: Iterable[DAGNode]) -> Sequence[Operator]:
43-
if AIRFLOW_V_3_1_PLUS:
43+
if AIRFLOW_V_3_0_PLUS:
4444
from airflow.sdk import BaseOperator
4545
from airflow.sdk.definitions.mappedoperator import MappedOperator
4646
else:

providers/standard/tests/unit/standard/operators/test_python.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -845,6 +845,54 @@ def test_xcom_push_skipped_tasks(self):
845845
"skipped": ["empty_task"]
846846
}
847847

848+
@pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Airflow 2 implementation is different")
849+
def test_short_circuit_operator_skips_sensors(self):
850+
"""Test that ShortCircuitOperator properly skips sensors in Airflow 3.x."""
851+
from airflow.sdk.bases.sensor import BaseSensorOperator
852+
853+
# Create a sensor similar to S3FileSensor to reproduce the issue
854+
class CustomS3Sensor(BaseSensorOperator):
855+
def __init__(self, bucket_name: str, object_key: str, **kwargs):
856+
super().__init__(**kwargs)
857+
self.bucket_name = bucket_name
858+
self.object_key = object_key
859+
self.timeout = 0
860+
self.poke_interval = 0
861+
862+
def poke(self, context):
863+
# Simulate sensor logic
864+
return True
865+
866+
with self.dag_maker(self.dag_id):
867+
# ShortCircuit that evaluates to False (should skip all downstream)
868+
short_circuit = ShortCircuitOperator(
869+
task_id="check_dis_is_mon_to_fri_not_holiday",
870+
python_callable=lambda: False, # This causes skipping
871+
)
872+
873+
sensor_task = CustomS3Sensor(
874+
task_id="wait_for_ticker_to_secid_lookup_s3_file",
875+
bucket_name="test-bucket",
876+
object_key="ticker_to_secid_lookup.csv",
877+
)
878+
879+
short_circuit >> sensor_task
880+
881+
dr = self.dag_maker.create_dagrun()
882+
883+
self.dag_maker.run_ti("check_dis_is_mon_to_fri_not_holiday", dr)
884+
885+
# Verify the sensor is included in the skip list by checking XCom
886+
# (this was the bug - sensors were not being included in skip list)
887+
tis = dr.get_task_instances()
888+
xcom_data = tis[0].xcom_pull(task_ids="check_dis_is_mon_to_fri_not_holiday", key="skipmixin_key")
889+
890+
assert xcom_data is not None, "XCom data should exist"
891+
skipped_task_ids = set(xcom_data.get("skipped", []))
892+
assert "wait_for_ticker_to_secid_lookup_s3_file" in skipped_task_ids, (
893+
"Sensor should be skipped by ShortCircuitOperator"
894+
)
895+
848896

849897
virtualenv_string_args: list[str] = []
850898

providers/standard/tests/unit/standard/utils/test_skipmixin.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -359,3 +359,84 @@ def test_raise_exception_on_not_valid_branch_task_ids(self, dag_maker, branch_ta
359359
error_message = r"'branch_task_ids' must contain only valid task_ids. Invalid tasks found: .*"
360360
with pytest.raises(AirflowException, match=error_message):
361361
SkipMixin().skip_all_except(ti=ti1, branch_task_ids=branch_task_ids)
362+
363+
@pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Issue only exists in Airflow 3.x")
364+
def test_ensure_tasks_includes_sensors_airflow_3x(self, dag_maker):
365+
"""Test that sensors (inheriting from airflow.sdk.BaseOperator) are properly handled by _ensure_tasks."""
366+
from airflow.providers.standard.utils.skipmixin import _ensure_tasks
367+
from airflow.sdk import BaseOperator as SDKBaseOperator
368+
from airflow.sdk.bases.sensor import BaseSensorOperator
369+
370+
class DummySensor(BaseSensorOperator):
371+
def __init__(self, **kwargs):
372+
super().__init__(**kwargs)
373+
self.timeout = 0
374+
self.poke_interval = 0
375+
376+
def poke(self, context):
377+
return True
378+
379+
with dag_maker("dag_test_sensor_skipping") as dag:
380+
regular_task = EmptyOperator(task_id="regular_task")
381+
sensor_task = DummySensor(task_id="sensor_task")
382+
downstream_task = EmptyOperator(task_id="downstream_task")
383+
384+
regular_task >> [sensor_task, downstream_task]
385+
386+
dag_maker.create_dagrun(run_id=DEFAULT_DAG_RUN_ID)
387+
388+
downstream_nodes = dag.get_task("regular_task").downstream_list
389+
task_list = _ensure_tasks(downstream_nodes)
390+
391+
# Verify both the regular operator and sensor are included
392+
task_ids = [t.task_id for t in task_list]
393+
assert "sensor_task" in task_ids, "Sensor should be included in task list"
394+
assert "downstream_task" in task_ids, "Regular task should be included in task list"
395+
assert len(task_list) == 2, "Both tasks should be included"
396+
397+
# Also verify that the sensor is actually an instance of the correct BaseOperator
398+
sensor_in_list = next((t for t in task_list if t.task_id == "sensor_task"), None)
399+
assert sensor_in_list is not None, "Sensor task should be found in list"
400+
assert isinstance(sensor_in_list, SDKBaseOperator), "Sensor should be instance of SDK BaseOperator"
401+
402+
@pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Integration test for Airflow 3.x sensor skipping")
403+
def test_skip_sensor_in_branching_scenario(self, dag_maker):
404+
"""Integration test: verify sensors are properly skipped by branching operators in Airflow 3.x."""
405+
from airflow.sdk.bases.sensor import BaseSensorOperator
406+
407+
# Create a dummy sensor for testing
408+
class DummySensor(BaseSensorOperator):
409+
def __init__(self, **kwargs):
410+
super().__init__(**kwargs)
411+
self.timeout = 0
412+
self.poke_interval = 0
413+
414+
def poke(self, context):
415+
return True
416+
417+
with dag_maker("dag_test_branch_sensor_skipping"):
418+
branch_task = EmptyOperator(task_id="branch_task")
419+
regular_task = EmptyOperator(task_id="regular_task")
420+
sensor_task = DummySensor(task_id="sensor_task")
421+
branch_task >> [regular_task, sensor_task]
422+
423+
dag_maker.create_dagrun(run_id=DEFAULT_DAG_RUN_ID)
424+
425+
dag_version = DagVersion.get_latest_version(branch_task.dag_id)
426+
ti_branch = TI(branch_task, run_id=DEFAULT_DAG_RUN_ID, dag_version_id=dag_version.id)
427+
428+
# Test skipping the sensor (follow regular_task branch)
429+
with pytest.raises(DownstreamTasksSkipped) as exc_info:
430+
SkipMixin().skip_all_except(ti=ti_branch, branch_task_ids="regular_task")
431+
432+
# Verify that the sensor task is properly marked for skipping
433+
skipped_tasks = set(exc_info.value.tasks)
434+
assert ("sensor_task", -1) in skipped_tasks, "Sensor task should be marked for skipping"
435+
436+
# Test skipping the regular task (follow sensor_task branch)
437+
with pytest.raises(DownstreamTasksSkipped) as exc_info:
438+
SkipMixin().skip_all_except(ti=ti_branch, branch_task_ids="sensor_task")
439+
440+
# Verify that the regular task is properly marked for skipping
441+
skipped_tasks = set(exc_info.value.tasks)
442+
assert ("regular_task", -1) in skipped_tasks, "Regular task should be marked for skipping"

0 commit comments

Comments
 (0)