Skip to content

Commit 1cda358

Browse files
authored
Clean up SDK references in airflow.models.expandinput (#59815)
1 parent 12f6fbd commit 1cda358

File tree

11 files changed

+101
-121
lines changed

11 files changed

+101
-121
lines changed

airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,11 +64,11 @@
6464
from airflow.models.asset import AssetActive
6565
from airflow.models.dag import DagModel
6666
from airflow.models.dagrun import DagRun as DR
67+
from airflow.models.expandinput import NotFullyPopulated
6768
from airflow.models.taskinstance import TaskInstance as TI, _stop_remaining_tasks
6869
from airflow.models.taskreschedule import TaskReschedule
6970
from airflow.models.trigger import Trigger
7071
from airflow.models.xcom import XComModel
71-
from airflow.sdk.definitions._internal.expandinput import NotFullyPopulated
7272
from airflow.serialization.definitions.assets import SerializedAsset, SerializedAssetUniqueKey
7373
from airflow.serialization.definitions.dag import SerializedDAG
7474
from airflow.task.trigger_rule import TriggerRule

airflow-core/src/airflow/models/expandinput.py

Lines changed: 67 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -24,45 +24,83 @@
2424

2525
import attrs
2626

27-
from airflow.sdk.definitions._internal.expandinput import (
28-
DictOfListsExpandInput,
29-
ListOfDictsExpandInput,
30-
MappedArgument,
31-
NotFullyPopulated,
32-
OperatorExpandArgument,
33-
OperatorExpandKwargsArgument,
34-
is_mappable,
35-
)
36-
3727
if TYPE_CHECKING:
38-
from typing import TypeGuard
28+
from collections.abc import Mapping, Sequence
29+
from typing import TypeAlias
3930

4031
from sqlalchemy.orm import Session
32+
from typing_extensions import TypeIs
4133

4234
from airflow.serialization.definitions.mappedoperator import Operator
4335
from airflow.serialization.definitions.xcom_arg import SchedulerXComArg
4436

37+
ExpandArgument: TypeAlias = "SchedulerMappedArgument" | SchedulerXComArg | Sequence | Mapping[str, Any]
38+
ExpandKwargsArgument: TypeAlias = SchedulerXComArg | Sequence[SchedulerXComArg | Mapping[str, Any]]
39+
4540

4641
__all__ = [
47-
"DictOfListsExpandInput",
48-
"ListOfDictsExpandInput",
49-
"MappedArgument",
5042
"NotFullyPopulated",
51-
"OperatorExpandArgument",
52-
"OperatorExpandKwargsArgument",
53-
"is_mappable",
43+
"SchedulerMappedArgument",
44+
"SchedulerDictOfListsExpandInput",
45+
"SchedulerListOfDictsExpandInput",
5446
]
5547

5648

57-
def _needs_run_time_resolution(v: OperatorExpandArgument) -> TypeGuard[MappedArgument | SchedulerXComArg]:
49+
class NotFullyPopulated(RuntimeError):
50+
"""
51+
Raise when mapped length cannot be calculated due to incomplete metadata.
52+
53+
This is generally because not all upstream tasks have been completed (or in
54+
parse-time length calculations, when any upstream has runtime dependencies
55+
on mapped length) when the function is called.
56+
"""
57+
58+
def __init__(self, missing: set[str]) -> None:
59+
self.missing = missing
60+
61+
def __str__(self) -> str:
62+
keys = ", ".join(repr(k) for k in sorted(self.missing))
63+
return f"Failed to populate all mapping metadata; missing: {keys}"
64+
65+
66+
def _needs_run_time_resolution(v: ExpandArgument) -> TypeIs[SchedulerMappedArgument | SchedulerXComArg]:
5867
from airflow.serialization.definitions.xcom_arg import SchedulerXComArg
5968

60-
return isinstance(v, (MappedArgument, SchedulerXComArg))
69+
return isinstance(v, (SchedulerMappedArgument, SchedulerXComArg))
70+
71+
72+
@attrs.define(kw_only=True)
73+
class SchedulerMappedArgument:
74+
"""
75+
Stand-in stub for task-group-mapping arguments.
76+
77+
This corresponds to SDK's ``MappedArgument``, which is created when
78+
dynamically mapping a task group, and an argument used to dynamic-map is
79+
passed into a task inside the group.
80+
81+
This value is not currently used anywhere in the scheduler since nested
82+
dynamic mapping is not supported (i.e. using this value to further expand
83+
an operator inside a mapped task group), but this is implemented so the
84+
value is displayed better in the UI.
85+
"""
86+
87+
_input: SchedulerExpandInput = attrs.field()
88+
_key: str
89+
90+
def iter_references(self) -> Iterable[tuple[Operator, str]]:
91+
yield from self._input.iter_references()
6192

6293

6394
@attrs.define
6495
class SchedulerDictOfListsExpandInput:
65-
value: dict
96+
"""
97+
Serialized storage of a mapped operator's mapped kwargs.
98+
99+
This corresponds to SDK's ``DictOfListsExpandInput``, which was created by
100+
calling ``expand(**kwargs)`` on an operator type.
101+
"""
102+
103+
value: dict[str, ExpandArgument]
66104

67105
EXPAND_INPUT_TYPE: ClassVar[str] = "dict-of-lists"
68106

@@ -90,7 +128,7 @@ def _get_map_lengths(self, run_id: str, *, session: Session) -> dict[str, int]:
90128

91129
# TODO: This initiates one database call for each XComArg. Would it be
92130
# more efficient to do one single db call and unpack the value here?
93-
def _get_length(v: OperatorExpandArgument) -> int | None:
131+
def _get_length(v: ExpandArgument) -> int | None:
94132
if isinstance(v, SchedulerXComArg):
95133
return get_task_map_length(v, run_id, session=session)
96134

@@ -123,7 +161,14 @@ def iter_references(self) -> Iterable[tuple[Operator, str]]:
123161

124162
@attrs.define
125163
class SchedulerListOfDictsExpandInput:
126-
value: list
164+
"""
165+
Serialized storage of a mapped operator's mapped kwargs.
166+
167+
This corresponds to SDK's ``ListOfDictsExpandInput``, which was created by
168+
calling ``expand_kwargs(xcom_arg)`` on an operator type.
169+
"""
170+
171+
value: ExpandKwargsArgument
127172

128173
EXPAND_INPUT_TYPE: ClassVar[str] = "list-of-dicts"
129174

airflow-core/src/airflow/serialization/encoders.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,12 +59,19 @@
5959
if TYPE_CHECKING:
6060
from dateutil.relativedelta import relativedelta
6161

62+
from airflow.sdk.definitions._internal.expandinput import ExpandInput
6263
from airflow.sdk.definitions.asset import BaseAsset
6364
from airflow.triggers.base import BaseEventTrigger
6465

6566
T = TypeVar("T")
6667

6768

69+
def encode_expand_input(var: ExpandInput) -> dict[str, Any]:
70+
from airflow.serialization.serialized_objects import BaseSerialization
71+
72+
return {"type": var.EXPAND_INPUT_TYPE, "value": BaseSerialization.serialize(var.value)}
73+
74+
6875
def encode_relativedelta(var: relativedelta) -> dict[str, Any]:
6976
"""Encode a relativedelta object."""
7077
encoded = {k: v for k, v in var.__dict__.items() if not k.startswith("_") and v}

airflow-core/src/airflow/serialization/enums.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,3 +69,4 @@ class DagAttributeTypes(str, Enum):
6969
DAG_CALLBACK_REQUEST = "dag_callback_request"
7070
TASK_INSTANCE_KEY = "task_instance_key"
7171
DEADLINE_ALERT = "deadline_alert"
72+
MAPPED_ARGUMENT = "mapped_argument"

airflow-core/src/airflow/serialization/serialized_objects.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,11 @@
4646
from airflow.callbacks.callback_requests import DagCallbackRequest, TaskCallbackRequest
4747
from airflow.exceptions import AirflowException, DeserializationError, SerializationError
4848
from airflow.models.connection import Connection
49-
from airflow.models.expandinput import create_expand_input
49+
from airflow.models.expandinput import SchedulerMappedArgument, create_expand_input
5050
from airflow.models.taskinstancekey import TaskInstanceKey
5151
from airflow.sdk import DAG, Asset, AssetAlias, BaseOperator, XComArg
5252
from airflow.sdk.bases.operator import OPERATOR_DEFAULTS # TODO: Copy this into the scheduler?
53+
from airflow.sdk.definitions._internal.expandinput import MappedArgument
5354
from airflow.sdk.definitions.asset import (
5455
AssetAliasEvent,
5556
AssetAliasUniqueKey,
@@ -81,6 +82,7 @@
8182
from airflow.serialization.encoders import (
8283
coerce_to_core_timetable,
8384
encode_asset_like,
85+
encode_expand_input,
8486
encode_relativedelta,
8587
encode_timetable,
8688
encode_timezone,
@@ -654,6 +656,9 @@ def serialize(
654656
return cls._encode(var.to_json(), type_=DAT.TASK_CALLBACK_REQUEST)
655657
elif isinstance(var, DagCallbackRequest):
656658
return cls._encode(var.to_json(), type_=DAT.DAG_CALLBACK_REQUEST)
659+
elif isinstance(var, MappedArgument):
660+
data = {"input": encode_expand_input(var._input), "key": var._key}
661+
return cls._encode(data, type_=DAT.MAPPED_ARGUMENT)
657662
elif var.__class__ == Context:
658663
d = {}
659664
for k, v in var.items():
@@ -763,6 +768,9 @@ def deserialize(cls, encoded_var: Any) -> Any:
763768
return DagCallbackRequest.from_json(var)
764769
elif type_ == DAT.TASK_INSTANCE_KEY:
765770
return TaskInstanceKey(**var)
771+
elif type_ == DAT.MAPPED_ARGUMENT:
772+
expand_input = create_expand_input(var["input"]["type"], var["input"]["value"])
773+
return SchedulerMappedArgument(input=expand_input, key=var["key"])
766774
elif type_ == DAT.ARG_NOT_SET:
767775
from airflow.serialization.definitions.notset import NOTSET
768776

@@ -1033,10 +1041,7 @@ def serialize_mapped_operator(cls, op: MappedOperator) -> dict[str, Any]:
10331041
expansion_kwargs = op._get_specified_expand_input()
10341042
if TYPE_CHECKING: # Let Mypy check the input type for us!
10351043
_ExpandInputRef.validate_expand_input_value(expansion_kwargs.value)
1036-
serialized_op[op._expand_input_attr] = {
1037-
"type": type(expansion_kwargs).EXPAND_INPUT_TYPE,
1038-
"value": cls.serialize(expansion_kwargs.value),
1039-
}
1044+
serialized_op[op._expand_input_attr] = encode_expand_input(expansion_kwargs)
10401045

10411046
if op.partial_kwargs:
10421047
serialized_op["partial_kwargs"] = {}
@@ -2179,11 +2184,7 @@ def serialize_task_group(cls, task_group: TaskGroup) -> dict[str, Any] | None:
21792184
}
21802185

21812186
if isinstance(task_group, MappedTaskGroup):
2182-
expand_input = task_group._expand_input
2183-
encoded["expand_input"] = {
2184-
"type": expand_input.EXPAND_INPUT_TYPE,
2185-
"value": cls.serialize(expand_input.value),
2186-
}
2187+
encoded["expand_input"] = encode_expand_input(task_group._expand_input)
21872188
encoded["is_mapped"] = True
21882189

21892190
return encoded

airflow-core/tests/unit/models/test_dagrun.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1456,7 +1456,7 @@ def task_2(arg2): ...
14561456
task_2.expand(arg2=[1, 2])
14571457

14581458
# Update it to use the new serialized DAG
1459-
dr.dag = dag_maker.dag
1459+
dr.dag = dag_maker.serialized_model.dag
14601460
dag_version_id = DagVersion.get_latest_version(dag_id=dr.dag_id, session=session).id
14611461
dr.verify_integrity(dag_version_id=dag_version_id, session=session)
14621462

@@ -1522,10 +1522,10 @@ def task_2(arg2): ...
15221522
]
15231523

15241524
# Now "increase" the length of literal
1525-
with dag_maker(session=session, serialized=True) as dag:
1525+
with dag_maker(session=session, serialized=True):
15261526
task_2.expand(arg2=[1, 2, 3, 4, 5])
15271527

1528-
dr.dag = dag
1528+
dr.dag = dag_maker.serialized_model.dag
15291529
# Every mapped task is revised at task_instance_scheduling_decision
15301530
dr.task_instance_scheduling_decisions()
15311531

@@ -1565,7 +1565,7 @@ def task_2(arg2): ...
15651565
with dag_maker(session=session):
15661566
task_2.expand(arg2=[1, 2])
15671567

1568-
dr.dag = dag_maker.dag
1568+
dr.dag = dag_maker.serialized_model.dag
15691569
# Since we change the literal on the dag file itself, the dag_hash will
15701570
# change which will have the scheduler verify the dr integrity
15711571
dag_version_id = DagVersion.get_latest_version(dag_id=dr.dag_id, session=session).id

providers/standard/tests/unit/standard/decorators/test_python.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@
4848
from airflow.decorators.base import DecoratedMappedOperator # type: ignore[no-redef]
4949
from airflow.models.baseoperator import BaseOperator # type: ignore[no-redef]
5050
from airflow.models.dag import DAG # type: ignore[assignment,no-redef]
51-
from airflow.models.expandinput import DictOfListsExpandInput
51+
from airflow.models.expandinput import DictOfListsExpandInput # type: ignore[attr-defined,no-redef]
5252
from airflow.models.xcom_arg import XComArg # type: ignore[no-redef]
5353
from airflow.utils.task_group import TaskGroup # type: ignore[no-redef]
5454

task-sdk/src/airflow/sdk/definitions/_internal/abstractoperator.py

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,6 @@
2828
)
2929
from typing import TYPE_CHECKING, Any, ClassVar, TypeAlias
3030

31-
import methodtools
32-
3331
from airflow.sdk import TriggerRule, WeightRule
3432
from airflow.sdk.configuration import conf
3533
from airflow.sdk.definitions._internal.mixins import DependencyMixin
@@ -83,10 +81,6 @@
8381
log = logging.getLogger(__name__)
8482

8583

86-
class NotMapped(Exception):
87-
"""Raise if a task is neither mapped nor has any parent mapped groups."""
88-
89-
9084
class AbstractOperator(Templater, DAGNode):
9185
"""
9286
Common implementation for operators, including unmapped and mapped.
@@ -408,21 +402,3 @@ def get_needs_expansion(self) -> bool:
408402
else:
409403
self._needs_expansion = False
410404
return self._needs_expansion
411-
412-
@methodtools.lru_cache(maxsize=None)
413-
def get_parse_time_mapped_ti_count(self) -> int:
414-
"""
415-
Return the number of mapped task instances that can be created on Dag run creation.
416-
417-
This only considers literal mapped arguments, and would return *None*
418-
when any non-literal values are used for mapping.
419-
420-
:raise NotFullyPopulated: If non-literal mapped arguments are encountered.
421-
:raise NotMapped: If the operator is neither mapped, nor has any parent
422-
mapped task groups.
423-
:return: Total number of mapped TIs this task should have.
424-
"""
425-
group = self.get_closest_mapped_task_group()
426-
if group is None:
427-
raise NotMapped()
428-
return group.get_parse_time_mapped_ti_count()

task-sdk/src/airflow/sdk/definitions/_internal/expandinput.py

Lines changed: 6 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,6 @@
1717
# under the License.
1818
from __future__ import annotations
1919

20-
import functools
21-
import operator
2220
from collections.abc import Iterable, Mapping, Sequence, Sized
2321
from typing import TYPE_CHECKING, Any, ClassVar, Union
2422

@@ -43,12 +41,13 @@
4341
OperatorExpandKwargsArgument = Union["XComArg", Sequence[Union["XComArg", Mapping[str, Any]]]]
4442

4543

46-
class NotFullyPopulated(RuntimeError):
44+
class _NotFullyPopulated(RuntimeError):
4745
"""
48-
Raise when ``get_map_lengths`` cannot populate all mapping metadata.
46+
Raise when an expand input cannot be resolved due to incomplete metadata.
4947
50-
This is generally due to not all upstream tasks have finished when the
51-
function is called.
48+
This generally should not happen. The scheduler should have made sure that
49+
a not-yet-ready-to-expand task is not executed. In the off chance
50+
this gets raised, it will fail the task instance.
5251
"""
5352

5453
def __init__(self, missing: set[str]) -> None:
@@ -123,15 +122,6 @@ def _iter_parse_time_resolved_kwargs(self) -> Iterable[tuple[str, Sized]]:
123122
"""Generate kwargs with values available on parse-time."""
124123
return ((k, v) for k, v in self.value.items() if _is_parse_time_mappable(v))
125124

126-
def get_parse_time_mapped_ti_count(self) -> int:
127-
if not self.value:
128-
return 0
129-
literal_values = [len(v) for _, v in self._iter_parse_time_resolved_kwargs()]
130-
if len(literal_values) != len(self.value):
131-
literal_keys = (k for k, _ in self._iter_parse_time_resolved_kwargs())
132-
raise NotFullyPopulated(set(self.value).difference(literal_keys))
133-
return functools.reduce(operator.mul, literal_values, 1)
134-
135125
def _get_map_lengths(
136126
self, resolved_vals: dict[str, Sized], upstream_map_indexes: dict[str, int]
137127
) -> dict[str, int]:
@@ -160,7 +150,7 @@ def _get_length(k: str, v: OperatorExpandArgument) -> int | None:
160150
k: res for k, v in self.value.items() if v is not None if (res := _get_length(k, v)) is not None
161151
}
162152
if len(map_lengths) < len(self.value):
163-
raise NotFullyPopulated(set(self.value).difference(map_lengths))
153+
raise _NotFullyPopulated(set(self.value).difference(map_lengths))
164154
return map_lengths
165155

166156
def _expand_mapped_field(self, key: str, value: Any, map_index: int, all_lengths: dict[str, int]) -> Any:
@@ -236,11 +226,6 @@ class ListOfDictsExpandInput(ResolveMixin):
236226

237227
EXPAND_INPUT_TYPE: ClassVar[str] = "list-of-dicts"
238228

239-
def get_parse_time_mapped_ti_count(self) -> int:
240-
if isinstance(self.value, Sized):
241-
return len(self.value)
242-
raise NotFullyPopulated({"expand_kwargs() argument"})
243-
244229
def iter_references(self) -> Iterable[tuple[Operator, str]]:
245230
from airflow.sdk.definitions.xcom_arg import XComArg
246231

0 commit comments

Comments
 (0)