Skip to content

Commit eec2743

Browse files
authored
Move MappedOperator to serialization (#59628)
1 parent 01fe16d commit eec2743

File tree

29 files changed

+115
-151
lines changed

29 files changed

+115
-151
lines changed

airflow-core/src/airflow/api/common/mark_tasks.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from __future__ import annotations
2121

2222
from collections.abc import Collection, Iterable, Iterator
23-
from typing import TYPE_CHECKING, TypeAlias
23+
from typing import TYPE_CHECKING
2424

2525
from sqlalchemy import and_, or_, select
2626
from sqlalchemy.orm import lazyload
@@ -33,11 +33,8 @@
3333
if TYPE_CHECKING:
3434
from sqlalchemy.orm import Session as SASession
3535

36-
from airflow.models.mappedoperator import MappedOperator
37-
from airflow.serialization.definitions.baseoperator import SerializedBaseOperator
3836
from airflow.serialization.definitions.dag import SerializedDAG
39-
40-
Operator: TypeAlias = MappedOperator | SerializedBaseOperator
37+
from airflow.serialization.definitions.mappedoperator import Operator
4138

4239

4340
@provide_session

airflow-core/src/airflow/api_fastapi/core_api/services/ui/grid.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,9 @@
2424

2525
from airflow.api_fastapi.common.parameters import state_priority
2626
from airflow.api_fastapi.core_api.services.ui.task_group import get_task_group_children_getter
27-
from airflow.models.mappedoperator import MappedOperator
2827
from airflow.models.taskmap import TaskMap
2928
from airflow.serialization.definitions.baseoperator import SerializedBaseOperator
29+
from airflow.serialization.definitions.mappedoperator import SerializedMappedOperator
3030
from airflow.serialization.definitions.taskgroup import SerializedTaskGroup
3131

3232
log = structlog.get_logger(logger_name=__name__)
@@ -90,7 +90,7 @@ def _find_aggregates(
9090

9191
if node is None:
9292
return
93-
if isinstance(node, MappedOperator):
93+
if isinstance(node, SerializedMappedOperator):
9494
# For unmapped tasks, reflect a single None state so UI shows one square
9595
mapped_details = details or [{"state": None, "start_date": None, "end_date": None}]
9696
yield {

airflow-core/src/airflow/api_fastapi/core_api/services/ui/task_group.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@
2424
from operator import methodcaller
2525

2626
from airflow.configuration import conf
27-
from airflow.models.mappedoperator import MappedOperator, is_mapped
2827
from airflow.serialization.definitions.baseoperator import SerializedBaseOperator
28+
from airflow.serialization.definitions.mappedoperator import SerializedMappedOperator, is_mapped
2929

3030

3131
@cache
@@ -39,7 +39,7 @@ def get_task_group_children_getter() -> Callable:
3939

4040
def task_group_to_dict(task_item_or_group, parent_group_is_mapped=False):
4141
"""Create a nested dict representation of this TaskGroup and its children used to construct the Graph."""
42-
if isinstance(task := task_item_or_group, (SerializedBaseOperator, MappedOperator)):
42+
if isinstance(task := task_item_or_group, (SerializedBaseOperator, SerializedMappedOperator)):
4343
# we explicitly want the short task ID here, not the full doted notation if in a group
4444
task_display_name = task.task_display_name if task.task_display_name != task.task_id else task.label
4545
node_operator = {
@@ -83,7 +83,7 @@ def task_group_to_dict(task_item_or_group, parent_group_is_mapped=False):
8383

8484
def task_group_to_dict_grid(task_item_or_group, parent_group_is_mapped=False):
8585
"""Create a nested dict representation of this TaskGroup and its children used to construct the Grid."""
86-
if isinstance(task := task_item_or_group, (MappedOperator, SerializedBaseOperator)):
86+
if isinstance(task := task_item_or_group, (SerializedMappedOperator, SerializedBaseOperator)):
8787
mapped = None
8888
if parent_group_is_mapped or is_mapped(task):
8989
mapped = True

airflow-core/src/airflow/cli/commands/task_command.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,11 +62,9 @@
6262

6363
from sqlalchemy.orm.session import Session
6464

65-
from airflow.models.mappedoperator import MappedOperator
66-
from airflow.serialization.definitions.baseoperator import SerializedBaseOperator
65+
from airflow.serialization.definitions.mappedoperator import Operator
6766

6867
CreateIfNecessary = Literal[False, "db", "memory"]
69-
Operator = MappedOperator | SerializedBaseOperator
7068

7169
log = logging.getLogger(__name__)
7270

airflow-core/src/airflow/models/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ def __getattr__(name):
109109
"Deadline": "airflow.models.deadline",
110110
"Log": "airflow.models.log",
111111
"HITLDetail": "airflow.models.hitl",
112-
"MappedOperator": "airflow.models.mappedoperator",
112+
"MappedOperator": "airflow.sdk.definitions.mappedoperator",
113113
"Param": "airflow.sdk.definitions.param",
114114
"Pool": "airflow.models.pool",
115115
"RenderedTaskInstanceFields": "airflow.models.renderedtifields",
@@ -136,7 +136,6 @@ def __getattr__(name):
136136
from airflow.models.db_callback_request import DbCallbackRequest
137137
from airflow.models.deadline import Deadline
138138
from airflow.models.log import Log
139-
from airflow.models.mappedoperator import MappedOperator
140139
from airflow.models.pool import Pool
141140
from airflow.models.renderedtifields import RenderedTaskInstanceFields
142141
from airflow.models.skipmixin import SkipMixin
@@ -147,6 +146,7 @@ def __getattr__(name):
147146
from airflow.models.variable import Variable
148147
from airflow.sdk import DAG, BaseOperator, BaseOperatorLink, Param
149148
from airflow.sdk.bases.xcom import BaseXCom
149+
from airflow.sdk.definitions.mappedoperator import MappedOperator
150150
from airflow.sdk.execution_time.xcom import XCom
151151

152152

airflow-core/src/airflow/models/dag.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,16 +70,13 @@
7070
if TYPE_CHECKING:
7171
from typing import TypeAlias
7272

73-
from airflow.models.mappedoperator import MappedOperator
7473
from airflow.serialization.definitions.assets import (
7574
SerializedAsset,
7675
SerializedAssetAlias,
7776
SerializedAssetBase,
7877
)
79-
from airflow.serialization.definitions.baseoperator import SerializedBaseOperator
8078
from airflow.serialization.definitions.dag import SerializedDAG
8179

82-
Operator: TypeAlias = MappedOperator | SerializedBaseOperator
8380
UKey: TypeAlias = SerializedAssetUniqueKey
8481

8582
log = logging.getLogger(__name__)

airflow-core/src/airflow/models/dagrun.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -104,17 +104,15 @@
104104

105105
from airflow._shared.observability.traces.base_tracer import EmptySpan
106106
from airflow.models.dag_version import DagVersion
107-
from airflow.models.mappedoperator import MappedOperator
108107
from airflow.models.taskinstancekey import TaskInstanceKey
109108
from airflow.sdk import DAG as SDKDAG
110-
from airflow.serialization.definitions.baseoperator import SerializedBaseOperator
111109
from airflow.serialization.definitions.dag import SerializedDAG
110+
from airflow.serialization.definitions.mappedoperator import Operator
112111

113112
CreatedTasks = TypeVar("CreatedTasks", Iterator["dict[str, Any]"], Iterator[TI])
114113
AttributeValueType: TypeAlias = (
115114
str | bool | int | float | Sequence[str] | Sequence[bool] | Sequence[int] | Sequence[float]
116115
)
117-
Operator: TypeAlias = MappedOperator | SerializedBaseOperator
118116

119117
RUN_ID_REGEX = r"^(?:manual|scheduled|asset_triggered)__(?:\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\+00:00)$"
120118

@@ -1529,7 +1527,7 @@ def _expand_mapped_task_if_needed(ti: TI) -> Iterable[TI] | None:
15291527
If the ti does not need expansion, either because the task is not
15301528
mapped, or has already been expanded, *None* is returned.
15311529
"""
1532-
from airflow.models.mappedoperator import is_mapped
1530+
from airflow.serialization.definitions.mappedoperator import is_mapped
15331531

15341532
if TYPE_CHECKING:
15351533
assert ti.task
@@ -1749,7 +1747,7 @@ def _check_for_removed_or_restored_tasks(
17491747
17501748
"""
17511749
from airflow.models.expandinput import NotFullyPopulated
1752-
from airflow.models.mappedoperator import get_mapped_ti_count
1750+
from airflow.serialization.definitions.mappedoperator import get_mapped_ti_count
17531751

17541752
tis = self.get_task_instances(session=session)
17551753

@@ -1891,7 +1889,7 @@ def _create_tasks(
18911889
:param task_creator: Function to create task instances
18921890
"""
18931891
from airflow.models.expandinput import NotFullyPopulated
1894-
from airflow.models.mappedoperator import get_mapped_ti_count
1892+
from airflow.serialization.definitions.mappedoperator import get_mapped_ti_count
18951893

18961894
map_indexes: Iterable[int]
18971895
for task in tasks:
@@ -1965,7 +1963,7 @@ def _revise_map_indexes_if_mapped(
19651963
for more details.
19661964
"""
19671965
from airflow.models.expandinput import NotFullyPopulated
1968-
from airflow.models.mappedoperator import get_mapped_ti_count
1966+
from airflow.serialization.definitions.mappedoperator import get_mapped_ti_count
19691967
from airflow.settings import task_instance_mutation_hook
19701968

19711969
try:

airflow-core/src/airflow/models/expandinput.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,16 +35,13 @@
3535
)
3636

3737
if TYPE_CHECKING:
38-
from typing import TypeAlias, TypeGuard
38+
from typing import TypeGuard
3939

4040
from sqlalchemy.orm import Session
4141

42-
from airflow.models.mappedoperator import MappedOperator
43-
from airflow.serialization.definitions.baseoperator import SerializedBaseOperator
42+
from airflow.serialization.definitions.mappedoperator import Operator
4443
from airflow.serialization.definitions.xcom_arg import SchedulerXComArg
4544

46-
Operator: TypeAlias = MappedOperator | SerializedBaseOperator
47-
4845

4946
__all__ = [
5047
"DictOfListsExpandInput",

airflow-core/src/airflow/models/referencemixin.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,8 @@
2121

2222
if TYPE_CHECKING:
2323
from collections.abc import Iterable
24-
from typing import TypeAlias
2524

26-
from airflow.models.mappedoperator import MappedOperator
27-
from airflow.serialization.definitions.baseoperator import SerializedBaseOperator
28-
29-
Operator: TypeAlias = MappedOperator | SerializedBaseOperator
25+
from airflow.serialization.definitions.mappedoperator import Operator
3026

3127

3228
@runtime_checkable

airflow-core/src/airflow/models/taskinstance.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@
103103

104104
if TYPE_CHECKING:
105105
from datetime import datetime
106-
from typing import Literal, TypeAlias
106+
from typing import Literal
107107

108108
import pendulum
109109
from sqlalchemy.engine import Connection as SAConnection, Engine
@@ -114,14 +114,11 @@
114114
from airflow.api_fastapi.execution_api.datamodels.asset import AssetProfile
115115
from airflow.models.dag import DagModel
116116
from airflow.models.dagrun import DagRun
117-
from airflow.models.mappedoperator import MappedOperator
118-
from airflow.serialization.definitions.baseoperator import SerializedBaseOperator
119117
from airflow.serialization.definitions.dag import SerializedDAG
118+
from airflow.serialization.definitions.mappedoperator import Operator
120119
from airflow.serialization.definitions.taskgroup import SerializedTaskGroup
121120
from airflow.utils.context import Context
122121

123-
Operator: TypeAlias = MappedOperator | SerializedBaseOperator
124-
125122

126123
PAST_DEPENDS_MET = "past_depends_met"
127124

@@ -1660,7 +1657,6 @@ def get_template_context(
16601657
session = settings.get_session()()
16611658

16621659
from airflow.exceptions import NotMapped
1663-
from airflow.models.mappedoperator import get_mapped_ti_count
16641660
from airflow.sdk.api.datamodels._generated import (
16651661
DagRun as DagRunSDK,
16661662
PrevSuccessfulDagRunResponse,
@@ -1669,6 +1665,7 @@ def get_template_context(
16691665
from airflow.sdk.definitions.param import process_params
16701666
from airflow.sdk.execution_time.context import InletEventsAccessors
16711667
from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
1668+
from airflow.serialization.definitions.mappedoperator import get_mapped_ti_count
16721669
from airflow.utils.context import (
16731670
ConnectionAccessor,
16741671
OutletEventAccessors,
@@ -2193,7 +2190,7 @@ def _find_common_ancestor_mapped_group(node1: Operator, node2: Operator) -> Seri
21932190

21942191
def _is_further_mapped_inside(operator: Operator, container: SerializedTaskGroup) -> bool:
21952192
"""Whether given operator is *further* mapped inside a task group."""
2196-
from airflow.models.mappedoperator import is_mapped
2193+
from airflow.serialization.definitions.mappedoperator import is_mapped
21972194

21982195
if is_mapped(operator):
21992196
return True
@@ -2260,7 +2257,7 @@ def tg2(inp):
22602257
:return: Specific map index or map indexes to pull, or ``None`` if we
22612258
want to "whole" return value (i.e. no mapped task groups involved).
22622259
"""
2263-
from airflow.models.mappedoperator import get_mapped_ti_count
2260+
from airflow.serialization.definitions.mappedoperator import get_mapped_ti_count
22642261

22652262
# This value should never be None since we already know the current task
22662263
# is in a mapped task group, and should have been expanded, despite that,
@@ -2307,7 +2304,7 @@ def find_relevant_relatives(
23072304
run_id: str,
23082305
session: Session,
23092306
) -> Collection[str | tuple[str, int]]:
2310-
from airflow.models.mappedoperator import get_mapped_ti_count
2307+
from airflow.serialization.definitions.mappedoperator import get_mapped_ti_count
23112308

23122309
visited: set[str | tuple[str, int]] = set()
23132310

0 commit comments

Comments
 (0)