3434)
3535from airflow .models .trigger import TriggerFailureReason
3636from airflow .providers .standard .operators .empty import EmptyOperator
37- from airflow .sdk import timezone
37+ from airflow .sdk import TaskInstanceState , timezone
3838from airflow .sdk .bases .sensor import BaseSensorOperator , PokeReturnValue , poke_mode_only
3939from airflow .sdk .definitions .dag import DAG
4040from airflow .sdk .execution_time .comms import RescheduleTask , TaskRescheduleStartDate
4141from airflow .sdk .timezone import datetime
42- from airflow .utils .state import State
4342
4443if TYPE_CHECKING :
4544 from airflow .sdk .definitions .context import Context
@@ -172,22 +171,22 @@ def test_ok_with_reschedule(self, run_task, make_sensor, time_machine):
172171
173172 state , msg , _ = run_task (task = sensor )
174173
175- assert state == State .UP_FOR_RESCHEDULE
174+ assert state == TaskInstanceState .UP_FOR_RESCHEDULE
176175 assert msg .reschedule_date == date1 + timedelta (seconds = sensor .poke_interval )
177176
178177 # second poke returns False and task is re-scheduled
179178 time_machine .coordinates .shift (sensor .poke_interval )
180179 date2 = date1 + timedelta (seconds = sensor .poke_interval )
181180 state , msg , _ = run_task (task = sensor )
182181
183- assert state == State .UP_FOR_RESCHEDULE
182+ assert state == TaskInstanceState .UP_FOR_RESCHEDULE
184183 assert msg .reschedule_date == date2 + timedelta (seconds = sensor .poke_interval )
185184
186185 # third poke returns True and task succeeds
187186 time_machine .coordinates .shift (sensor .poke_interval )
188187 state , _ , _ = run_task (task = sensor )
189188
190- assert state == State .SUCCESS
189+ assert state == TaskInstanceState .SUCCESS
191190
192191 def test_fail_with_reschedule (self , run_task , make_sensor , time_machine , mock_supervisor_comms ):
193192 sensor = make_sensor (return_value = False , poke_interval = 10 , timeout = 5 , mode = "reschedule" )
@@ -198,7 +197,7 @@ def test_fail_with_reschedule(self, run_task, make_sensor, time_machine, mock_su
198197
199198 state , msg , _ = run_task (task = sensor )
200199
201- assert state == State .UP_FOR_RESCHEDULE
200+ assert state == TaskInstanceState .UP_FOR_RESCHEDULE
202201 assert msg .reschedule_date == date1 + timedelta (seconds = sensor .poke_interval )
203202
204203 # second poke returns False, timeout occurs
@@ -208,7 +207,7 @@ def test_fail_with_reschedule(self, run_task, make_sensor, time_machine, mock_su
208207 mock_supervisor_comms .send .return_value = TaskRescheduleStartDate (start_date = date1 )
209208 state , msg , error = run_task (task = sensor , context_update = {"task_reschedule_count" : 1 })
210209
211- assert state == State .FAILED
210+ assert state == TaskInstanceState .FAILED
212211 assert isinstance (error , AirflowSensorTimeout )
213212
214213 def test_soft_fail_with_reschedule (self , run_task , make_sensor , time_machine , mock_supervisor_comms ):
@@ -221,15 +220,15 @@ def test_soft_fail_with_reschedule(self, run_task, make_sensor, time_machine, mo
221220 time_machine .move_to (date1 , tick = False )
222221
223222 state , msg , _ = run_task (task = sensor )
224- assert state == State .UP_FOR_RESCHEDULE
223+ assert state == TaskInstanceState .UP_FOR_RESCHEDULE
225224
226225 # second poke returns False, timeout occurs
227226 time_machine .coordinates .shift (sensor .poke_interval )
228227
229228 # Mocking values from DB/API-server
230229 mock_supervisor_comms .send .return_value = TaskRescheduleStartDate (start_date = date1 )
231230 state , msg , _ = run_task (task = sensor , context_update = {"task_reschedule_count" : 1 })
232- assert state == State .SKIPPED
231+ assert state == TaskInstanceState .SKIPPED
233232
234233 def test_ok_with_reschedule_and_exponential_backoff (
235234 self , run_task , make_sensor , time_machine , mock_supervisor_comms
@@ -261,7 +260,7 @@ def run_duration():
261260 curr_date = curr_date + timedelta (seconds = new_interval )
262261 time_machine .coordinates .shift (new_interval )
263262 state , msg , _ = run_task (sensor , context_update = {"task_reschedule_count" : _poke_count })
264- assert state == State .UP_FOR_RESCHEDULE
263+ assert state == TaskInstanceState .UP_FOR_RESCHEDULE
265264 old_interval = new_interval
266265 new_interval = sensor ._get_next_poke_interval (task_start_date , run_duration , _poke_count )
267266 assert old_interval < new_interval # actual test
@@ -272,7 +271,7 @@ def run_duration():
272271 time_machine .coordinates .shift (new_interval )
273272
274273 state , msg , _ = run_task (sensor , context_update = {"task_reschedule_count" : false_count + 1 })
275- assert state == State .SUCCESS
274+ assert state == TaskInstanceState .SUCCESS
276275
277276 def test_invalid_mode (self ):
278277 with pytest .raises (AirflowException ):
@@ -291,22 +290,22 @@ def test_ok_with_custom_reschedule_exception(self, make_sensor, run_task):
291290 with time_machine .travel (date1 , tick = False ):
292291 state , msg , error = run_task (sensor )
293292
294- assert state == State .UP_FOR_RESCHEDULE
293+ assert state == TaskInstanceState .UP_FOR_RESCHEDULE
295294 assert isinstance (msg , RescheduleTask )
296295 assert msg .reschedule_date == date2
297296
298297 # second poke returns False and task is re-scheduled
299298 with time_machine .travel (date2 , tick = False ):
300299 state , msg , error = run_task (sensor )
301300
302- assert state == State .UP_FOR_RESCHEDULE
301+ assert state == TaskInstanceState .UP_FOR_RESCHEDULE
303302 assert isinstance (msg , RescheduleTask )
304303 assert msg .reschedule_date == date3
305304
306305 # third poke returns True and task succeeds
307306 with time_machine .travel (date3 , tick = False ):
308307 state , _ , _ = run_task (sensor )
309- assert state == State .SUCCESS
308+ assert state == TaskInstanceState .SUCCESS
310309
311310 def test_sensor_with_invalid_poke_interval (self ):
312311 negative_poke_interval = - 10
@@ -523,36 +522,36 @@ def _run_task():
523522 context_update = {"task_reschedule_count" : test_state ["task_reschedule_count" ]},
524523 )
525524
526- if state == State .UP_FOR_RESCHEDULE :
525+ if state == TaskInstanceState .UP_FOR_RESCHEDULE :
527526 test_state ["task_reschedule_count" ] += 1
528527 # Only set first_reschedule_date on the first successful reschedule
529528 if test_state ["first_reschedule_date" ] is None :
530529 test_state ["first_reschedule_date" ] = test_state ["current_time" ]
531- elif state == State .UP_FOR_RETRY :
530+ elif state == TaskInstanceState .UP_FOR_RETRY :
532531 test_state ["try_number" ] += 1
533532 return state , msg , error
534533
535534 # Phase 1: Initial execution until failure
536535 # First poke - should reschedule
537536 state , _ , _ = _run_task ()
538- assert state == State .UP_FOR_RESCHEDULE
537+ assert state == TaskInstanceState .UP_FOR_RESCHEDULE
539538
540539 # Second poke - should raise RuntimeError and retry
541540 test_state ["current_time" ] += timedelta (seconds = sensor .poke_interval )
542541 state , _ , error = _run_task ()
543- assert state == State .UP_FOR_RETRY
542+ assert state == TaskInstanceState .UP_FOR_RETRY
544543 assert isinstance (error , RuntimeError )
545544
546545 # Third poke - should reschedule again
547546 test_state ["current_time" ] += sensor .retry_delay + timedelta (seconds = 1 )
548547 state , _ , _ = _run_task ()
549- assert state == State .UP_FOR_RESCHEDULE
548+ assert state == TaskInstanceState .UP_FOR_RESCHEDULE
550549
551550 # Fourth poke - should timeout
552551 test_state ["current_time" ] += timedelta (seconds = sensor .poke_interval )
553552 state , _ , error = _run_task ()
554553 assert isinstance (error , AirflowSensorTimeout )
555- assert state == State .FAILED
554+ assert state == TaskInstanceState .FAILED
556555
557556 # Phase 2: After clearing the failed sensor
558557 # Reset supervisor comms to return None, simulating a fresh start after clearing
@@ -564,13 +563,13 @@ def _run_task():
564563 for _ in range (3 ):
565564 test_state ["current_time" ] += timedelta (seconds = sensor .poke_interval )
566565 state , _ , _ = _run_task ()
567- assert state == State .UP_FOR_RESCHEDULE
566+ assert state == TaskInstanceState .UP_FOR_RESCHEDULE
568567
569568 # Final poke - should timeout
570569 test_state ["current_time" ] += timedelta (seconds = sensor .poke_interval )
571570 state , _ , error = _run_task ()
572571 assert isinstance (error , AirflowSensorTimeout )
573- assert state == State .FAILED
572+ assert state == TaskInstanceState .FAILED
574573
575574 def test_sensor_with_xcom (self , make_sensor ):
576575 xcom_value = "TestValue"
@@ -615,7 +614,7 @@ def timeout():
615614 state , _ , error = run_task (task = task , dag_id = f"test_sensor_timeout_{ mode } _{ retries } " )
616615
617616 assert isinstance (error , AirflowSensorTimeout )
618- assert state == State .FAILED
617+ assert state == TaskInstanceState .FAILED
619618
620619
621620@poke_mode_only
0 commit comments