@@ -359,3 +359,84 @@ def test_raise_exception_on_not_valid_branch_task_ids(self, dag_maker, branch_ta
359359 error_message = r"'branch_task_ids' must contain only valid task_ids. Invalid tasks found: .*"
360360 with pytest .raises (AirflowException , match = error_message ):
361361 SkipMixin ().skip_all_except (ti = ti1 , branch_task_ids = branch_task_ids )
362+
363+ @pytest .mark .skipif (not AIRFLOW_V_3_0_PLUS , reason = "Issue only exists in Airflow 3.x" )
364+ def test_ensure_tasks_includes_sensors_airflow_3x (self , dag_maker ):
365+ """Test that sensors (inheriting from airflow.sdk.BaseOperator) are properly handled by _ensure_tasks."""
366+ from airflow .providers .standard .utils .skipmixin import _ensure_tasks
367+ from airflow .sdk import BaseOperator as SDKBaseOperator
368+ from airflow .sdk .bases .sensor import BaseSensorOperator
369+
370+ class DummySensor (BaseSensorOperator ):
371+ def __init__ (self , ** kwargs ):
372+ super ().__init__ (** kwargs )
373+ self .timeout = 0
374+ self .poke_interval = 0
375+
376+ def poke (self , context ):
377+ return True
378+
379+ with dag_maker ("dag_test_sensor_skipping" ) as dag :
380+ regular_task = EmptyOperator (task_id = "regular_task" )
381+ sensor_task = DummySensor (task_id = "sensor_task" )
382+ downstream_task = EmptyOperator (task_id = "downstream_task" )
383+
384+ regular_task >> [sensor_task , downstream_task ]
385+
386+ dag_maker .create_dagrun (run_id = DEFAULT_DAG_RUN_ID )
387+
388+ downstream_nodes = dag .get_task ("regular_task" ).downstream_list
389+ task_list = _ensure_tasks (downstream_nodes )
390+
391+ # Verify both the regular operator and sensor are included
392+ task_ids = [t .task_id for t in task_list ]
393+ assert "sensor_task" in task_ids , "Sensor should be included in task list"
394+ assert "downstream_task" in task_ids , "Regular task should be included in task list"
395+ assert len (task_list ) == 2 , "Both tasks should be included"
396+
397+ # Also verify that the sensor is actually an instance of the correct BaseOperator
398+ sensor_in_list = next ((t for t in task_list if t .task_id == "sensor_task" ), None )
399+ assert sensor_in_list is not None , "Sensor task should be found in list"
400+ assert isinstance (sensor_in_list , SDKBaseOperator ), "Sensor should be instance of SDK BaseOperator"
401+
402+ @pytest .mark .skipif (not AIRFLOW_V_3_0_PLUS , reason = "Integration test for Airflow 3.x sensor skipping" )
403+ def test_skip_sensor_in_branching_scenario (self , dag_maker ):
404+ """Integration test: verify sensors are properly skipped by branching operators in Airflow 3.x."""
405+ from airflow .sdk .bases .sensor import BaseSensorOperator
406+
407+ # Create a dummy sensor for testing
408+ class DummySensor (BaseSensorOperator ):
409+ def __init__ (self , ** kwargs ):
410+ super ().__init__ (** kwargs )
411+ self .timeout = 0
412+ self .poke_interval = 0
413+
414+ def poke (self , context ):
415+ return True
416+
417+ with dag_maker ("dag_test_branch_sensor_skipping" ):
418+ branch_task = EmptyOperator (task_id = "branch_task" )
419+ regular_task = EmptyOperator (task_id = "regular_task" )
420+ sensor_task = DummySensor (task_id = "sensor_task" )
421+ branch_task >> [regular_task , sensor_task ]
422+
423+ dag_maker .create_dagrun (run_id = DEFAULT_DAG_RUN_ID )
424+
425+ dag_version = DagVersion .get_latest_version (branch_task .dag_id )
426+ ti_branch = TI (branch_task , run_id = DEFAULT_DAG_RUN_ID , dag_version_id = dag_version .id )
427+
428+ # Test skipping the sensor (follow regular_task branch)
429+ with pytest .raises (DownstreamTasksSkipped ) as exc_info :
430+ SkipMixin ().skip_all_except (ti = ti_branch , branch_task_ids = "regular_task" )
431+
432+ # Verify that the sensor task is properly marked for skipping
433+ skipped_tasks = set (exc_info .value .tasks )
434+ assert ("sensor_task" , - 1 ) in skipped_tasks , "Sensor task should be marked for skipping"
435+
436+ # Test skipping the regular task (follow sensor_task branch)
437+ with pytest .raises (DownstreamTasksSkipped ) as exc_info :
438+ SkipMixin ().skip_all_except (ti = ti_branch , branch_task_ids = "sensor_task" )
439+
440+ # Verify that the regular task is properly marked for skipping
441+ skipped_tasks = set (exc_info .value .tasks )
442+ assert ("regular_task" , - 1 ) in skipped_tasks , "Regular task should be marked for skipping"
0 commit comments