Skip to content

Commit e2289f9

Browse files
authored
Remove state dependency from airflow core in sdk (apache#55292)
Remove state dependency from airflow core in sdk
1 parent b0d396a commit e2289f9

File tree

3 files changed

+42
-36
lines changed

3 files changed

+42
-36
lines changed

task-sdk/src/airflow/sdk/definitions/dag.py

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
RemovedInAirflow4Warning,
4444
TaskNotFound,
4545
)
46-
from airflow.sdk import TriggerRule
46+
from airflow.sdk import TaskInstanceState, TriggerRule
4747
from airflow.sdk.bases.operator import BaseOperator
4848
from airflow.sdk.definitions._internal.node import validate_key
4949
from airflow.sdk.definitions._internal.types import NOTSET, ArgNotSet
@@ -85,6 +85,15 @@
8585
"dag",
8686
]
8787

88+
FINISHED_STATES = frozenset(
89+
[
90+
TaskInstanceState.SUCCESS,
91+
TaskInstanceState.FAILED,
92+
TaskInstanceState.SKIPPED,
93+
TaskInstanceState.UPSTREAM_FAILED,
94+
TaskInstanceState.REMOVED,
95+
]
96+
)
8897

8998
DagStateChangeCallback = Callable[[Context], None]
9099
ScheduleInterval = None | str | timedelta | relativedelta
@@ -1166,10 +1175,9 @@ def test(
11661175
from airflow import settings
11671176
from airflow.configuration import secrets_backend_list
11681177
from airflow.models.dagrun import DagRun, get_or_create_dagrun
1169-
from airflow.sdk import DagRunState, TaskInstanceState, timezone
1178+
from airflow.sdk import DagRunState, timezone
11701179
from airflow.secrets.local_filesystem import LocalFilesystemBackend
11711180
from airflow.serialization.serialized_objects import SerializedDAG
1172-
from airflow.utils.state import State
11731181
from airflow.utils.types import DagRunTriggeredByType, DagRunType
11741182

11751183
exit_stack = ExitStack()
@@ -1291,7 +1299,7 @@ def test(
12911299
# triggerer may mark tasks scheduled so we read from DB
12921300
all_tis = set(dr.get_task_instances(session=session))
12931301
scheduled_tis = {x for x in all_tis if x.state == TaskInstanceState.SCHEDULED}
1294-
ids_unrunnable = {x for x in all_tis if x.state not in State.finished} - scheduled_tis
1302+
ids_unrunnable = {x for x in all_tis if x.state not in FINISHED_STATES} - scheduled_tis
12951303
if not scheduled_tis and ids_unrunnable:
12961304
log.warning("No tasks to run. unrunnable tasks: %s", ids_unrunnable)
12971305
time.sleep(1)
@@ -1331,7 +1339,7 @@ def test(
13311339
# Run the task locally
13321340
try:
13331341
if mark_success:
1334-
ti.set_state(State.SUCCESS)
1342+
ti.set_state(TaskInstanceState.SUCCESS)
13351343
log.info("[DAG TEST] Marking success for %s on %s", task, ti.logical_date)
13361344
else:
13371345
_run_task(ti=ti, task=task, run_triggerer=True)
@@ -1363,7 +1371,6 @@ def _run_task(
13631371
possible. This function is only meant for the `dag.test` function as a helper function.
13641372
"""
13651373
from airflow.sdk.module_loading import import_string
1366-
from airflow.utils.state import State
13671374

13681375
taskrun_result: TaskRunResult | None
13691376
log.info("[DAG TEST] starting task_id=%s map_index=%s", ti.task_id, ti.map_index)
@@ -1378,7 +1385,7 @@ def _run_task(
13781385

13791386
# The API Server expects the task instance to be in QUEUED state before
13801387
# it is run.
1381-
ti.set_state(State.QUEUED)
1388+
ti.set_state(TaskInstanceState.QUEUED)
13821389
task_sdk_ti = TaskInstanceSDK(
13831390
id=ti.id,
13841391
task_id=ti.task_id,
@@ -1394,12 +1401,12 @@ def _run_task(
13941401
ti.set_state(taskrun_result.ti.state)
13951402
ti.task = create_scheduler_operator(taskrun_result.ti.task)
13961403

1397-
if ti.state == State.DEFERRED and isinstance(msg, DeferTask) and run_triggerer:
1404+
if ti.state == TaskInstanceState.DEFERRED and isinstance(msg, DeferTask) and run_triggerer:
13981405
from airflow.utils.session import create_session
13991406

14001407
# API Server expects the task instance to be in QUEUED state before
14011408
# resuming from deferral.
1402-
ti.set_state(State.QUEUED)
1409+
ti.set_state(TaskInstanceState.QUEUED)
14031410

14041411
log.info("[DAG TEST] running trigger in line")
14051412
trigger = import_string(msg.classpath)(**msg.trigger_kwargs)
@@ -1410,15 +1417,15 @@ def _run_task(
14101417

14111418
# Set the state to SCHEDULED so that the task can be resumed.
14121419
with create_session() as session:
1413-
ti.state = State.SCHEDULED
1420+
ti.state = TaskInstanceState.SCHEDULED
14141421
session.add(ti)
14151422
continue
14161423

14171424
break
14181425
except Exception:
14191426
log.exception("[DAG TEST] Error running task %s", ti)
1420-
if ti.state not in State.finished:
1421-
ti.set_state(State.FAILED)
1427+
if ti.state not in FINISHED_STATES:
1428+
ti.set_state(TaskInstanceState.FAILED)
14221429
taskrun_result = None
14231430
break
14241431
raise

task-sdk/tests/task_sdk/api/test_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
DagRunStateResponse,
4040
HITLDetailResponse,
4141
HITLUser,
42+
TerminalTIState,
4243
VariableResponse,
4344
XComResponse,
4445
)
@@ -52,7 +53,6 @@
5253
RescheduleTask,
5354
TaskRescheduleStartDate,
5455
)
55-
from airflow.utils.state import TerminalTIState
5656

5757
if TYPE_CHECKING:
5858
from time_machine import TimeMachineFixture

task-sdk/tests/task_sdk/bases/test_sensor.py

Lines changed: 22 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,11 @@
3434
)
3535
from airflow.models.trigger import TriggerFailureReason
3636
from airflow.providers.standard.operators.empty import EmptyOperator
37-
from airflow.sdk import timezone
37+
from airflow.sdk import TaskInstanceState, timezone
3838
from airflow.sdk.bases.sensor import BaseSensorOperator, PokeReturnValue, poke_mode_only
3939
from airflow.sdk.definitions.dag import DAG
4040
from airflow.sdk.execution_time.comms import RescheduleTask, TaskRescheduleStartDate
4141
from airflow.sdk.timezone import datetime
42-
from airflow.utils.state import State
4342

4443
if TYPE_CHECKING:
4544
from airflow.sdk.definitions.context import Context
@@ -172,22 +171,22 @@ def test_ok_with_reschedule(self, run_task, make_sensor, time_machine):
172171

173172
state, msg, _ = run_task(task=sensor)
174173

175-
assert state == State.UP_FOR_RESCHEDULE
174+
assert state == TaskInstanceState.UP_FOR_RESCHEDULE
176175
assert msg.reschedule_date == date1 + timedelta(seconds=sensor.poke_interval)
177176

178177
# second poke returns False and task is re-scheduled
179178
time_machine.coordinates.shift(sensor.poke_interval)
180179
date2 = date1 + timedelta(seconds=sensor.poke_interval)
181180
state, msg, _ = run_task(task=sensor)
182181

183-
assert state == State.UP_FOR_RESCHEDULE
182+
assert state == TaskInstanceState.UP_FOR_RESCHEDULE
184183
assert msg.reschedule_date == date2 + timedelta(seconds=sensor.poke_interval)
185184

186185
# third poke returns True and task succeeds
187186
time_machine.coordinates.shift(sensor.poke_interval)
188187
state, _, _ = run_task(task=sensor)
189188

190-
assert state == State.SUCCESS
189+
assert state == TaskInstanceState.SUCCESS
191190

192191
def test_fail_with_reschedule(self, run_task, make_sensor, time_machine, mock_supervisor_comms):
193192
sensor = make_sensor(return_value=False, poke_interval=10, timeout=5, mode="reschedule")
@@ -198,7 +197,7 @@ def test_fail_with_reschedule(self, run_task, make_sensor, time_machine, mock_su
198197

199198
state, msg, _ = run_task(task=sensor)
200199

201-
assert state == State.UP_FOR_RESCHEDULE
200+
assert state == TaskInstanceState.UP_FOR_RESCHEDULE
202201
assert msg.reschedule_date == date1 + timedelta(seconds=sensor.poke_interval)
203202

204203
# second poke returns False, timeout occurs
@@ -208,7 +207,7 @@ def test_fail_with_reschedule(self, run_task, make_sensor, time_machine, mock_su
208207
mock_supervisor_comms.send.return_value = TaskRescheduleStartDate(start_date=date1)
209208
state, msg, error = run_task(task=sensor, context_update={"task_reschedule_count": 1})
210209

211-
assert state == State.FAILED
210+
assert state == TaskInstanceState.FAILED
212211
assert isinstance(error, AirflowSensorTimeout)
213212

214213
def test_soft_fail_with_reschedule(self, run_task, make_sensor, time_machine, mock_supervisor_comms):
@@ -221,15 +220,15 @@ def test_soft_fail_with_reschedule(self, run_task, make_sensor, time_machine, mo
221220
time_machine.move_to(date1, tick=False)
222221

223222
state, msg, _ = run_task(task=sensor)
224-
assert state == State.UP_FOR_RESCHEDULE
223+
assert state == TaskInstanceState.UP_FOR_RESCHEDULE
225224

226225
# second poke returns False, timeout occurs
227226
time_machine.coordinates.shift(sensor.poke_interval)
228227

229228
# Mocking values from DB/API-server
230229
mock_supervisor_comms.send.return_value = TaskRescheduleStartDate(start_date=date1)
231230
state, msg, _ = run_task(task=sensor, context_update={"task_reschedule_count": 1})
232-
assert state == State.SKIPPED
231+
assert state == TaskInstanceState.SKIPPED
233232

234233
def test_ok_with_reschedule_and_exponential_backoff(
235234
self, run_task, make_sensor, time_machine, mock_supervisor_comms
@@ -261,7 +260,7 @@ def run_duration():
261260
curr_date = curr_date + timedelta(seconds=new_interval)
262261
time_machine.coordinates.shift(new_interval)
263262
state, msg, _ = run_task(sensor, context_update={"task_reschedule_count": _poke_count})
264-
assert state == State.UP_FOR_RESCHEDULE
263+
assert state == TaskInstanceState.UP_FOR_RESCHEDULE
265264
old_interval = new_interval
266265
new_interval = sensor._get_next_poke_interval(task_start_date, run_duration, _poke_count)
267266
assert old_interval < new_interval # actual test
@@ -272,7 +271,7 @@ def run_duration():
272271
time_machine.coordinates.shift(new_interval)
273272

274273
state, msg, _ = run_task(sensor, context_update={"task_reschedule_count": false_count + 1})
275-
assert state == State.SUCCESS
274+
assert state == TaskInstanceState.SUCCESS
276275

277276
def test_invalid_mode(self):
278277
with pytest.raises(AirflowException):
@@ -291,22 +290,22 @@ def test_ok_with_custom_reschedule_exception(self, make_sensor, run_task):
291290
with time_machine.travel(date1, tick=False):
292291
state, msg, error = run_task(sensor)
293292

294-
assert state == State.UP_FOR_RESCHEDULE
293+
assert state == TaskInstanceState.UP_FOR_RESCHEDULE
295294
assert isinstance(msg, RescheduleTask)
296295
assert msg.reschedule_date == date2
297296

298297
# second poke returns False and task is re-scheduled
299298
with time_machine.travel(date2, tick=False):
300299
state, msg, error = run_task(sensor)
301300

302-
assert state == State.UP_FOR_RESCHEDULE
301+
assert state == TaskInstanceState.UP_FOR_RESCHEDULE
303302
assert isinstance(msg, RescheduleTask)
304303
assert msg.reschedule_date == date3
305304

306305
# third poke returns True and task succeeds
307306
with time_machine.travel(date3, tick=False):
308307
state, _, _ = run_task(sensor)
309-
assert state == State.SUCCESS
308+
assert state == TaskInstanceState.SUCCESS
310309

311310
def test_sensor_with_invalid_poke_interval(self):
312311
negative_poke_interval = -10
@@ -523,36 +522,36 @@ def _run_task():
523522
context_update={"task_reschedule_count": test_state["task_reschedule_count"]},
524523
)
525524

526-
if state == State.UP_FOR_RESCHEDULE:
525+
if state == TaskInstanceState.UP_FOR_RESCHEDULE:
527526
test_state["task_reschedule_count"] += 1
528527
# Only set first_reschedule_date on the first successful reschedule
529528
if test_state["first_reschedule_date"] is None:
530529
test_state["first_reschedule_date"] = test_state["current_time"]
531-
elif state == State.UP_FOR_RETRY:
530+
elif state == TaskInstanceState.UP_FOR_RETRY:
532531
test_state["try_number"] += 1
533532
return state, msg, error
534533

535534
# Phase 1: Initial execution until failure
536535
# First poke - should reschedule
537536
state, _, _ = _run_task()
538-
assert state == State.UP_FOR_RESCHEDULE
537+
assert state == TaskInstanceState.UP_FOR_RESCHEDULE
539538

540539
# Second poke - should raise RuntimeError and retry
541540
test_state["current_time"] += timedelta(seconds=sensor.poke_interval)
542541
state, _, error = _run_task()
543-
assert state == State.UP_FOR_RETRY
542+
assert state == TaskInstanceState.UP_FOR_RETRY
544543
assert isinstance(error, RuntimeError)
545544

546545
# Third poke - should reschedule again
547546
test_state["current_time"] += sensor.retry_delay + timedelta(seconds=1)
548547
state, _, _ = _run_task()
549-
assert state == State.UP_FOR_RESCHEDULE
548+
assert state == TaskInstanceState.UP_FOR_RESCHEDULE
550549

551550
# Fourth poke - should timeout
552551
test_state["current_time"] += timedelta(seconds=sensor.poke_interval)
553552
state, _, error = _run_task()
554553
assert isinstance(error, AirflowSensorTimeout)
555-
assert state == State.FAILED
554+
assert state == TaskInstanceState.FAILED
556555

557556
# Phase 2: After clearing the failed sensor
558557
# Reset supervisor comms to return None, simulating a fresh start after clearing
@@ -564,13 +563,13 @@ def _run_task():
564563
for _ in range(3):
565564
test_state["current_time"] += timedelta(seconds=sensor.poke_interval)
566565
state, _, _ = _run_task()
567-
assert state == State.UP_FOR_RESCHEDULE
566+
assert state == TaskInstanceState.UP_FOR_RESCHEDULE
568567

569568
# Final poke - should timeout
570569
test_state["current_time"] += timedelta(seconds=sensor.poke_interval)
571570
state, _, error = _run_task()
572571
assert isinstance(error, AirflowSensorTimeout)
573-
assert state == State.FAILED
572+
assert state == TaskInstanceState.FAILED
574573

575574
def test_sensor_with_xcom(self, make_sensor):
576575
xcom_value = "TestValue"
@@ -615,7 +614,7 @@ def timeout():
615614
state, _, error = run_task(task=task, dag_id=f"test_sensor_timeout_{mode}_{retries}")
616615

617616
assert isinstance(error, AirflowSensorTimeout)
618-
assert state == State.FAILED
617+
assert state == TaskInstanceState.FAILED
619618

620619

621620
@poke_mode_only

0 commit comments

Comments (0)