Skip to content

Commit 20c7002

Browse files
Make start_date in Context nullable (#58175)
* Make start_date in Context nullable If a deadline is missed before a DagRun is started (could be queued for a very long time), there would be an exception when creating the Context because start_date is currently non-nullable. * Add cadwyn migration for nullable DagRun.start_date * more explicit downgrade Co-authored-by: Amogh Desai <amoghrajesh1999@gmail.com>" * merge conflict static checks --------- Co-authored-by: ferruzzi <ferruzzi@amazon.com>
1 parent 27426e4 commit 20c7002

File tree

9 files changed

+211
-6
lines changed

9 files changed

+211
-6
lines changed

airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -296,7 +296,7 @@ class DagRun(StrictBaseModel):
296296
data_interval_start: UtcDateTime | None
297297
data_interval_end: UtcDateTime | None
298298
run_after: UtcDateTime
299-
start_date: UtcDateTime
299+
start_date: UtcDateTime | None
300300
end_date: UtcDateTime | None
301301
clear_number: int = 0
302302
run_type: DagRunType

airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,20 @@
3535
)
3636
from airflow.api_fastapi.execution_api.versions.v2026_03_31 import (
3737
AddNoteField,
38+
MakeDagRunStartDateNullable,
3839
ModifyDeferredTaskKwargsToJsonValue,
3940
RemoveUpstreamMapIndexesField,
4041
)
4142

4243
bundle = VersionBundle(
4344
HeadVersion(),
44-
Version("2026-03-31", ModifyDeferredTaskKwargsToJsonValue, RemoveUpstreamMapIndexesField, AddNoteField),
45+
Version(
46+
"2026-03-31",
47+
MakeDagRunStartDateNullable,
48+
ModifyDeferredTaskKwargsToJsonValue,
49+
RemoveUpstreamMapIndexesField,
50+
AddNoteField,
51+
),
4552
Version("2025-12-08", MovePreviousRunEndpoint, AddDagRunDetailEndpoint),
4653
Version("2025-11-07", AddPartitionKeyField),
4754
Version("2025-11-05", AddTriggeringUserNameField),

airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_03_31.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
from cadwyn import ResponseInfo, VersionChange, convert_response_to_previous_version_for, schema
2323

24+
from airflow.api_fastapi.common.types import UtcDateTime
2425
from airflow.api_fastapi.execution_api.datamodels.taskinstance import (
2526
DagRun,
2627
TIDeferredStatePayload,
@@ -68,3 +69,29 @@ def remove_note_field(response: ResponseInfo) -> None: # type: ignore[misc]
6869
"""Remove note field for older API versions."""
6970
if "dag_run" in response.body and isinstance(response.body["dag_run"], dict):
7071
response.body["dag_run"].pop("note", None)
72+
73+
74+
class MakeDagRunStartDateNullable(VersionChange):
75+
"""Make DagRun.start_date field nullable for runs that haven't started yet."""
76+
77+
description = __doc__
78+
79+
instructions_to_migrate_to_previous_version = (schema(DagRun).field("start_date").had(type=UtcDateTime),)
80+
81+
@convert_response_to_previous_version_for(TIRunContext) # type: ignore[arg-type]
82+
def ensure_start_date_in_ti_run_context(response: ResponseInfo) -> None: # type: ignore[misc]
83+
"""
84+
Ensure start_date is never None in DagRun for previous API versions.
85+
86+
Older Task SDK clients expect start_date to be non-nullable. When the
87+
DagRun hasn't started yet (e.g. queued), fall back to run_after.
88+
"""
89+
dag_run = response.body.get("dag_run")
90+
if isinstance(dag_run, dict) and dag_run.get("start_date") is None:
91+
dag_run["start_date"] = dag_run.get("run_after")
92+
93+
@convert_response_to_previous_version_for(DagRun) # type: ignore[arg-type]
94+
def ensure_start_date_in_dag_run(response: ResponseInfo) -> None: # type: ignore[misc]
95+
"""Ensure start_date is never None in direct DagRun responses for previous API versions."""
96+
if response.body.get("start_date") is None:
97+
response.body["start_date"] = response.body.get("run_after")
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
from __future__ import annotations
19+
20+
import pytest
21+
22+
from airflow._shared.timezones import timezone
23+
from airflow.utils.state import DagRunState, State
24+
25+
from tests_common.test_utils.db import clear_db_runs
26+
27+
pytestmark = pytest.mark.db_test
28+
29+
TIMESTAMP_STR = "2024-09-30T12:00:00Z"
30+
TIMESTAMP = timezone.parse(TIMESTAMP_STR)
31+
32+
RUN_PATCH_BODY = {
33+
"state": "running",
34+
"hostname": "test-hostname",
35+
"unixname": "test-user",
36+
"pid": 12345,
37+
"start_date": TIMESTAMP_STR,
38+
}
39+
40+
41+
@pytest.fixture
42+
def old_ver_client(client):
43+
"""Client configured to use API version before start_date nullable change."""
44+
client.headers["Airflow-API-Version"] = "2025-12-08"
45+
return client
46+
47+
48+
class TestDagRunStartDateNullableBackwardCompat:
49+
"""Test that older API versions get a non-null start_date fallback."""
50+
51+
@pytest.fixture(autouse=True)
52+
def _freeze_time(self, time_machine):
53+
time_machine.move_to(TIMESTAMP_STR, tick=False)
54+
55+
def setup_method(self):
56+
clear_db_runs()
57+
58+
def teardown_method(self):
59+
clear_db_runs()
60+
61+
def test_old_version_gets_run_after_when_start_date_is_null(
62+
self,
63+
old_ver_client,
64+
session,
65+
create_task_instance,
66+
):
67+
ti = create_task_instance(
68+
task_id="test_start_date_nullable",
69+
state=State.QUEUED,
70+
dagrun_state=DagRunState.QUEUED,
71+
session=session,
72+
start_date=TIMESTAMP,
73+
)
74+
ti.dag_run.start_date = None # DagRun has not started yet
75+
session.commit()
76+
77+
response = old_ver_client.patch(f"/execution/task-instances/{ti.id}/run", json=RUN_PATCH_BODY)
78+
dag_run = response.json()["dag_run"]
79+
80+
assert response.status_code == 200
81+
assert dag_run["start_date"] is not None
82+
assert dag_run["start_date"] == dag_run["run_after"]
83+
84+
def test_head_version_allows_null_start_date(
85+
self,
86+
client,
87+
session,
88+
create_task_instance,
89+
):
90+
ti = create_task_instance(
91+
task_id="test_start_date_null_head",
92+
state=State.QUEUED,
93+
dagrun_state=DagRunState.QUEUED,
94+
session=session,
95+
start_date=TIMESTAMP,
96+
)
97+
ti.dag_run.start_date = None # DagRun has not started yet
98+
session.commit()
99+
100+
response = client.patch(f"/execution/task-instances/{ti.id}/run", json=RUN_PATCH_BODY)
101+
dag_run = response.json()["dag_run"]
102+
103+
assert response.status_code == 200
104+
assert dag_run["start_date"] is None
105+
106+
def test_old_version_preserves_real_start_date(
107+
self,
108+
old_ver_client,
109+
session,
110+
create_task_instance,
111+
):
112+
ti = create_task_instance(
113+
task_id="test_start_date_preserved",
114+
state=State.QUEUED,
115+
dagrun_state=DagRunState.RUNNING,
116+
session=session,
117+
start_date=TIMESTAMP,
118+
)
119+
assert ti.dag_run.start_date == TIMESTAMP # DagRun has already started
120+
session.commit()
121+
122+
response = old_ver_client.patch(f"/execution/task-instances/{ti.id}/run", json=RUN_PATCH_BODY)
123+
dag_run = response.json()["dag_run"]
124+
125+
assert response.status_code == 200
126+
assert dag_run["start_date"] is not None, "start_date should not be None when DagRun has started"
127+
assert dag_run["start_date"] == TIMESTAMP.isoformat().replace("+00:00", "Z")

task-sdk/src/airflow/sdk/api/datamodels/_generated.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -621,7 +621,7 @@ class DagRun(BaseModel):
621621
data_interval_start: Annotated[AwareDatetime | None, Field(title="Data Interval Start")] = None
622622
data_interval_end: Annotated[AwareDatetime | None, Field(title="Data Interval End")] = None
623623
run_after: Annotated[AwareDatetime, Field(title="Run After")]
624-
start_date: Annotated[AwareDatetime, Field(title="Start Date")]
624+
start_date: Annotated[AwareDatetime | None, Field(title="Start Date")] = None
625625
end_date: Annotated[AwareDatetime | None, Field(title="End Date")] = None
626626
clear_number: Annotated[int | None, Field(title="Clear Number")] = 0
627627
run_type: DagRunType

task-sdk/src/airflow/sdk/definitions/context.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ class Context(TypedDict, total=False):
6565
prev_end_date_success: NotRequired[DateTime | None]
6666
reason: NotRequired[str | None]
6767
run_id: str
68-
start_date: DateTime
68+
start_date: DateTime | None
6969
# TODO: Remove Operator from below once we have MappedOperator to the Task SDK
7070
# and once we can remove context related code from the Scheduler/models.TaskInstance
7171
task: BaseOperator | Operator

task-sdk/src/airflow/sdk/types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ class DagRunProtocol(Protocol):
8080
logical_date: AwareDatetime | None
8181
data_interval_start: AwareDatetime | None
8282
data_interval_end: AwareDatetime | None
83-
start_date: AwareDatetime
83+
start_date: AwareDatetime | None
8484
end_date: AwareDatetime | None
8585
run_type: Any
8686
run_after: AwareDatetime

task-sdk/tests/task_sdk/execution_time/test_context.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
import pytest
2424

2525
from airflow.sdk import BaseOperator, get_current_context, timezone
26-
from airflow.sdk.api.datamodels._generated import AssetEventResponse, AssetResponse
26+
from airflow.sdk.api.datamodels._generated import AssetEventResponse, AssetResponse, DagRun
2727
from airflow.sdk.bases.xcom import BaseXCom
2828
from airflow.sdk.definitions.asset import (
2929
Asset,
@@ -862,6 +862,34 @@ def test_source_task_instance_xcom_pull(self, sample_inlet_evnets_accessor, mock
862862
]
863863

864864

865+
class TestDagRunStartDateNullable:
866+
"""Test that DagRun and TIRunContext accept start_date=None (queued runs that haven't started)."""
867+
868+
def test_dag_run_model_accepts_null_start_date(self):
869+
"""DagRun datamodel should accept start_date=None for runs that haven't started yet."""
870+
dag_run = DagRun(
871+
dag_id="test_dag",
872+
run_id="test_run",
873+
logical_date="2024-12-01T01:00:00Z",
874+
data_interval_start="2024-12-01T00:00:00Z",
875+
data_interval_end="2024-12-01T01:00:00Z",
876+
start_date=None,
877+
run_after="2024-12-01T01:00:00Z",
878+
run_type="manual",
879+
state="queued",
880+
conf=None,
881+
consumed_asset_events=[],
882+
)
883+
884+
assert dag_run.start_date is None
885+
886+
def test_ti_run_context_with_null_start_date(self, make_ti_context):
887+
"""TIRunContext should be constructable when the DagRun has start_date=None."""
888+
ti_context = make_ti_context(start_date=None)
889+
890+
assert ti_context.dag_run.start_date is None
891+
892+
865893
class TestAsyncGetConnection:
866894
"""Test async connection retrieval with secrets backends."""
867895

0 commit comments

Comments
 (0)