Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,25 @@ Since the template is rendered after the main execution block, it is possible to
# The task instances will be named "aaa" and "bbb".
my_task.expand(my_value=["a", "b"])
Named mapping for task groups
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

When using mapped task groups, you can set ``map_index_template`` on the ``@task_group`` decorator. This makes the group's expansion arguments available in the rendering context for **all** child tasks in the group, not just the first one that receives the argument directly.

.. code-block:: python

    @task_group(map_index_template="{{ filename }}")
    def file_transforms(filename):
        extracted = extract(filename)
        load(extracted)


    file_transforms.expand(filename=["data1.json", "data2.json"])

In this example, both ``extract`` and ``load`` task instances will be labeled "data1.json" or "data2.json" based on the group's expansion argument. Without the group-level template, only ``extract`` would have access to ``filename`` in its rendering context, while ``load`` would only see ``extracted``.

If a child task also defines its own ``map_index_template``, the task-level template takes precedence over the group-level one.

Mapping with non-TaskFlow operators
===================================

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,12 @@ class TIRunContext(BaseModel):
should_retry: bool = False
"""If the ti encounters an error, whether it should enter retry or failed state."""

task_group_map_index_template: str | None = None
"""map_index_template from parent MappedTaskGroup, if any."""

task_group_expanded_args: dict[str, Any] | None = None
"""Resolved expansion arguments from parent MappedTaskGroup for this specific map_index."""


class PrevSuccessfulDagRunResponse(BaseModel):
"""Schema for response with previous successful DagRun information for Task Template Context."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,11 @@ def ti_run(
context.next_method = ti.next_method
context.next_kwargs = ti.next_kwargs

if ti.map_index >= 0:
_populate_task_group_map_index_context(
context, ti.dag_id, ti.task_id, ti.map_index, ti.run_id, session, dag_bag
)

return context
except SQLAlchemyError:
log.exception("Error marking Task Instance state as running")
Expand Down Expand Up @@ -921,6 +926,100 @@ def _iter_breadcrumbs() -> Iterator[dict[str, Any]]:
return TaskBreadcrumbsResponse(breadcrumbs=_iter_breadcrumbs())


def _populate_task_group_map_index_context(
    context: TIRunContext,
    dag_id: str,
    task_id: str,
    map_index: int,
    run_id: str,
    session: SessionDep,
    dag_bag: DagBagDep,
) -> None:
    """
    Copy the enclosing mapped task group's ``map_index_template`` onto *context*.

    The DAG is looked up through ``get_latest_version_of_dag`` and the task through
    ``dag.task_dict``. The first mapped task group (in the order yielded by
    ``task.iter_mapped_task_groups()``) that defines a ``map_index_template`` wins;
    its template and its expansion arguments resolved for *map_index* are stored on
    *context* in place.

    This is best-effort: *context* is left untouched when the DAG lookup raises
    ``HTTPException``, when the task is not in the DAG, or when no enclosing mapped
    group defines a template.
    """
    try:
        dag = get_latest_version_of_dag(dag_bag, dag_id, session)
    except HTTPException:
        return

    task = dag.task_dict.get(task_id)
    if not task:
        return

    # iter_mapped_task_groups walks from innermost to outermost, so the first group
    # carrying a template is the closest one to the task.
    templated_group = next(
        (group for group in task.iter_mapped_task_groups() if group.map_index_template),
        None,
    )
    if templated_group is None:
        return

    context.task_group_map_index_template = templated_group.map_index_template
    context.task_group_expanded_args = _resolve_task_group_expand_args(
        templated_group._expand_input, map_index, run_id, session
    )


def _resolve_task_group_expand_args(
expand_input: Any,
map_index: int,
run_id: str,
session: SessionDep,
) -> dict[str, Any] | None:
"""Resolve the expand_input for a specific map_index to get the expanded arguments."""
from airflow.models.expandinput import SchedulerDictOfListsExpandInput, SchedulerListOfDictsExpandInput
from airflow.serialization.definitions.xcom_arg import SchedulerXComArg

def _resolve_at_index(value: Any) -> Any | None:
"""Resolve a single value (list/tuple or XComArg) at the given map_index."""
match value:
case SchedulerXComArg():
value = _resolve_xcom_arg_value(value, run_id, session)
case list() | tuple():
pass
case _:
return None
if isinstance(value, (list, tuple)) and map_index < len(value):
return value[map_index]
return None

match expand_input:
case SchedulerDictOfListsExpandInput(value=mapping):
resolved = {}
for key, val in mapping.items():
if (item := _resolve_at_index(val)) is not None:
resolved[key] = item
return resolved or None

case SchedulerListOfDictsExpandInput(value=val):
if isinstance(item := _resolve_at_index(val), dict):
return item

return None


def _resolve_xcom_arg_value(xcom_arg: Any, run_id: str, session: SessionDep) -> Any:
"""Resolve a SchedulerXComArg to its actual value via XCom query."""
refs = list(xcom_arg.iter_references())
if not refs:
return None
operator, key = refs[0]

xcom_value = session.scalar(
select(XComModel.value).where(
XComModel.dag_id == operator.dag_id,
XComModel.task_id == operator.task_id,
XComModel.run_id == run_id,
XComModel.key == key,
XComModel.map_index == -1,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this mean the group's expand input can't depend on mapped upstream task outputs? If so, this should be documented as a limitation.

Copy link
Contributor Author

@anishgirianish anishgirianish Feb 16, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great question! From what I can tell, the upstream producing the expand input is unmapped; it returns a list that drives the mapping. If it were mapped, it would go through expand_kwargs instead, so it wouldn't be a SchedulerXComArg here. But I'd love to hear if you've seen a case where this breaks.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not that familiar with xcom intricacies - perhaps it's worth adding a test that ensures this scenario works correctly.

)
)
if xcom_value is None:
return None
try:
return json.loads(xcom_value)
except (json.JSONDecodeError, TypeError):
log.debug("Failed to decode XCom value for task_group expand args", exc_info=True)
return None


def _is_eligible_to_retry(state: str, try_number: int, max_tries: int) -> bool:
"""Is task instance is eligible for retry."""
if state == TaskInstanceState.RESTARTING:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,11 @@
ModifyDeferredTaskKwargsToJsonValue,
RemoveUpstreamMapIndexesField,
)
from airflow.api_fastapi.execution_api.versions.v2026_04_15 import AddTaskGroupMapIndexTemplateFields

bundle = VersionBundle(
HeadVersion(),
Version("2026-04-15", AddTaskGroupMapIndexTemplateFields),
Version("2026-03-31", ModifyDeferredTaskKwargsToJsonValue, RemoveUpstreamMapIndexesField),
Version("2025-12-08", MovePreviousRunEndpoint, AddDagRunDetailEndpoint),
Version("2025-11-07", AddPartitionKeyField),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

from cadwyn import ResponseInfo, VersionChange, convert_response_to_previous_version_for, schema

from airflow.api_fastapi.execution_api.datamodels.taskinstance import TIRunContext


class AddTaskGroupMapIndexTemplateFields(VersionChange):
    """Add task_group_map_index_template and task_group_expanded_args fields to TIRunContext."""

    # The class docstring doubles as the description of this version change, so
    # editing the docstring above changes this runtime value as well.
    description = __doc__

    # Schema-level downgrade: in the previous API version neither field exists
    # on TIRunContext.
    instructions_to_migrate_to_previous_version = (
        schema(TIRunContext).field("task_group_map_index_template").didnt_exist,
        schema(TIRunContext).field("task_group_expanded_args").didnt_exist,
    )

    # Body-level downgrade registered for TIRunContext responses.
    # NOTE(review): the function is written without ``self``, which is presumably
    # what the ``type: ignore`` markers are for -- confirm against Cadwyn's docs.
    @convert_response_to_previous_version_for(TIRunContext)  # type: ignore[arg-type]
    def remove_task_group_fields(response: ResponseInfo) -> None:  # type: ignore[misc]
        """Remove task group map index fields for older API versions."""
        # ``pop`` with a ``None`` default is a no-op when the key is not in the body.
        response.body.pop("task_group_map_index_template", None)
        response.body.pop("task_group_expanded_args", None)
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ class SerializedTaskGroup(DAGNode):
parent_group: SerializedTaskGroup | None = attrs.field()
dag: SerializedDAG = attrs.field()
tooltip: str = attrs.field()
map_index_template: str | None = attrs.field(default=None)
default_args: dict[str, Any] = attrs.field(factory=dict)

# TODO: Are these actually useful?
Expand Down
1 change: 1 addition & 0 deletions airflow-core/src/airflow/serialization/schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,7 @@
"tooltip": { "type": "string" },
"ui_color": { "type": "string" },
"ui_fgcolor": { "type": "string" },
"map_index_template": { "type": ["null", "string"] },
"upstream_group_ids": {
"type": "array",
"items": { "type": "string" }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2088,6 +2088,7 @@ def serialize_task_group(cls, task_group: TaskGroup) -> dict[str, Any] | None:
"tooltip": task_group.tooltip,
"ui_color": task_group.ui_color,
"ui_fgcolor": task_group.ui_fgcolor,
"map_index_template": task_group.map_index_template,
"children": {
label: child.serialize_for_task_group() for label, child in task_group.children.items()
},
Expand Down Expand Up @@ -2118,6 +2119,7 @@ def deserialize_task_group(
for key in ["prefix_group_id", "tooltip", "ui_color", "ui_fgcolor"]
}
kwargs["group_display_name"] = cls.deserialize(encoded_group.get("group_display_name", ""))
kwargs["map_index_template"] = cls.deserialize(encoded_group.get("map_index_template"))

if not encoded_group.get("is_mapped"):
group = SerializedTaskGroup(group_id=group_id, parent_group=parent_group, dag=dag, **kwargs)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,88 @@ def task_3():
)
assert response.status_code == 200

def test_ti_run_includes_task_group_map_index_template(
    self, client: Client, dag_maker: DagMaker, session: Session
):
    """Test that ti_run includes task_group_map_index_template and expanded args for mapped task groups."""
    filenames = ["a.json", "b.json"]

    with dag_maker("test_tg_map_index_template", serialized=True):

        @task_group(map_index_template="{{ filename }}")
        def tg(filename):
            @task
            def process(filename):
                return filename

            process(filename)

        tg.expand(filename=filenames)

    dr = dag_maker.create_dagrun()

    # Order the mapped task instances explicitly so the instance under test does not
    # depend on whatever order get_task_instances() happens to return.
    mapped_tis = sorted(
        (ti for ti in dr.get_task_instances() if ti.map_index >= 0),
        key=lambda mapped_ti: mapped_ti.map_index,
    )
    assert len(mapped_tis) > 0

    # Set all mapped task instances to queued so they can be moved to running.
    for mapped_ti in mapped_tis:
        mapped_ti.set_state(State.QUEUED)
    session.flush()

    ti = mapped_tis[0]
    response = client.patch(
        f"/execution/task-instances/{ti.id}/run",
        json={
            "state": "running",
            "hostname": "random-hostname",
            "unixname": "random-unixname",
            "pid": 100,
            "start_date": "2024-09-30T12:00:00Z",
        },
    )

    assert response.status_code == 200
    data = response.json()

    assert data.get("task_group_map_index_template") == "{{ filename }}"
    # The expected argument is tied to the map index of the instance actually used.
    assert data.get("task_group_expanded_args") == {"filename": filenames[ti.map_index]}

def test_ti_run_no_task_group_fields_for_non_mapped(
    self, client, session, create_task_instance, time_machine
):
    """An unmapped task instance must not receive any task group fields in its run context."""
    instant_str = "2024-09-30T12:00:00Z"
    frozen_now = timezone.parse(instant_str)
    time_machine.move_to(frozen_now, tick=False)

    plain_ti = create_task_instance(
        task_id="test_non_mapped",
        state=State.QUEUED,
        dagrun_state=DagRunState.RUNNING,
        session=session,
        start_date=frozen_now,
        dag_id=str(uuid4()),
    )
    session.commit()

    run_payload = {
        "state": "running",
        "hostname": "random-hostname",
        "unixname": "random-unixname",
        "pid": 100,
        "start_date": instant_str,
    }
    response = client.patch(f"/execution/task-instances/{plain_ti.id}/run", json=run_payload)

    assert response.status_code == 200
    body = response.json()

    # Neither task group key may appear in the response of a non-mapped task.
    for field_name in ("task_group_map_index_template", "task_group_expanded_args"):
        assert field_name not in body

def test_dynamic_task_mapping_with_all_success_trigger_rule(self, dag_maker: DagMaker, session: Session):
"""Test that with ALL_SUCCESS trigger rule and skipped upstream, downstream should not run."""

Expand Down
27 changes: 27 additions & 0 deletions airflow-core/tests/unit/serialization/test_dag_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,7 @@ def _operator_defaults(overrides):
"tooltip": "",
"ui_color": "CornflowerBlue",
"ui_fgcolor": "#000",
"map_index_template": None,
"upstream_group_ids": [],
"downstream_group_ids": [],
"upstream_task_ids": [],
Expand Down Expand Up @@ -1714,6 +1715,30 @@ def check_task_group(node):

check_task_group(serialized_dag.task_group)

def test_task_group_map_index_template_serialization(self):
    """A TaskGroup's map_index_template must come back unchanged after serialize + deserialize."""
    from airflow.providers.standard.operators.empty import EmptyOperator

    template = "{{ filename }}"
    with DAG("test_tg_map_index_template", schedule=None, start_date=datetime(2020, 1, 1)) as dag:
        with TaskGroup("my_group", map_index_template=template):
            EmptyOperator(task_id="task1")

    encoded_dag = DagSerialization.serialize_dag(dag)
    round_tripped = DagSerialization.deserialize_dag(encoded_dag)

    restored_group = round_tripped.task_group.children["my_group"]
    assert restored_group.map_index_template == template

def test_task_group_map_index_template_none_by_default(self):
    """A TaskGroup created without map_index_template must deserialize with the attribute set to None."""
    from airflow.providers.standard.operators.empty import EmptyOperator

    with DAG("test_tg_mit_default", schedule=None, start_date=datetime(2020, 1, 1)) as dag:
        with TaskGroup("my_group"):
            EmptyOperator(task_id="task1")

    encoded_dag = DagSerialization.serialize_dag(dag)
    round_tripped = DagSerialization.deserialize_dag(encoded_dag)

    restored_group = round_tripped.task_group.children["my_group"]
    assert restored_group.map_index_template is None

@staticmethod
def assert_taskgroup_children(se_task_group, dag_task_group, expected_children):
assert se_task_group.children.keys() == dag_task_group.children.keys() == expected_children
Expand Down Expand Up @@ -3100,6 +3125,7 @@ def tg(a: str) -> None:
},
"group_display_name": "",
"is_mapped": True,
"map_index_template": None,
"prefix_group_id": True,
"tooltip": "",
"ui_color": "CornflowerBlue",
Expand All @@ -3113,6 +3139,7 @@ def tg(a: str) -> None:
serde_tg = serde_dag.task_group.children["tg"]
assert isinstance(serde_tg, SerializedTaskGroup)
assert serde_tg._expand_input == SchedulerDictOfListsExpandInput({"a": [".", ".."]})
assert serde_tg.map_index_template is None


@pytest.mark.db_test
Expand Down
4 changes: 3 additions & 1 deletion task-sdk/src/airflow/sdk/api/datamodels/_generated.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

from pydantic import AwareDatetime, BaseModel, ConfigDict, Field, JsonValue, RootModel

API_VERSION: Final[str] = "2026-03-31"
API_VERSION: Final[str] = "2026-04-15"


class AssetAliasReferenceAssetEventDagRun(BaseModel):
Expand Down Expand Up @@ -644,3 +644,5 @@ class TIRunContext(BaseModel):
next_kwargs: Annotated[dict[str, Any] | str | None, Field(title="Next Kwargs")] = None
xcom_keys_to_clear: Annotated[list[str] | None, Field(title="Xcom Keys To Clear")] = None
should_retry: Annotated[bool | None, Field(title="Should Retry")] = False
task_group_map_index_template: Annotated[str | None, Field(title="Task Group Map Index Template")] = None
task_group_expanded_args: Annotated[dict[str, Any] | None, Field(title="Task Group Expanded Args")] = None
Loading
Loading