diff --git a/airflow-core/docs/authoring-and-scheduling/dynamic-task-mapping.rst b/airflow-core/docs/authoring-and-scheduling/dynamic-task-mapping.rst
index 1a5366cc4c374..40141274b2146 100644
--- a/airflow-core/docs/authoring-and-scheduling/dynamic-task-mapping.rst
+++ b/airflow-core/docs/authoring-and-scheduling/dynamic-task-mapping.rst
@@ -217,6 +217,25 @@ Since the template is rendered after the main execution block, it is possible to
     # The task instances will be named "aaa" and "bbb".
     my_task.expand(my_value=["a", "b"])
 
+Named mapping for task groups
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+When using mapped task groups, you can set ``map_index_template`` on the ``@task_group`` decorator. This makes the group's expansion arguments available in the rendering context for **all** child tasks in the group, not just the first one that receives the argument directly.
+
+.. code-block:: python
+
+    @task_group(map_index_template="{{ filename }}")
+    def file_transforms(filename):
+        extracted = extract(filename)
+        load(extracted)
+
+
+    file_transforms.expand(filename=["data1.json", "data2.json"])
+
+In this example, both ``extract`` and ``load`` task instances will be labeled "data1.json" or "data2.json" based on the group's expansion argument. Without the group-level template, only ``extract`` would have access to ``filename`` in its rendering context, while ``load`` would only see ``extracted``.
+
+If a child task also defines its own ``map_index_template``, the task-level template takes precedence over the group-level one.
+
 Mapping with non-TaskFlow operators
 ===================================
 
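A minimal sketch (not part of the patch above) of the precedence rule described in the new docs. The inner ``load`` task, its template string, and the ``extract`` task are illustrative only; imports are assumed from the surrounding examples.

    @task_group(map_index_template="{{ filename }}")
    def file_transforms(filename):
        # No task-level template: falls back to the group-level one,
        # so instances render as "data1.json" / "data2.json".
        extracted = extract(filename)

        # Task-level template wins over the group-level one for this task only.
        @task(map_index_template="loaded by {{ ti.task_id }}")
        def load(data):
            ...

        load(extracted)


    file_transforms.expand(filename=["data1.json", "data2.json"])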
diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
index 513a99f6dc996..fdb27bc6fa83d 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
@@ -339,6 +339,12 @@ class TIRunContext(BaseModel):
     should_retry: bool = False
     """If the ti encounters an error, whether it should enter retry or failed state."""
 
+    task_group_map_index_template: str | None = None
+    """map_index_template from parent MappedTaskGroup, if any."""
+
+    task_group_expanded_args: dict[str, Any] | None = None
+    """Resolved expansion arguments from parent MappedTaskGroup for this specific map_index."""
+
 
 class PrevSuccessfulDagRunResponse(BaseModel):
     """Schema for response with previous successful DagRun information for Task Template Context."""
diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
index f22d7c125853d..a2d92e49c4082 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
@@ -264,6 +264,11 @@ def ti_run(
             context.next_method = ti.next_method
             context.next_kwargs = ti.next_kwargs
 
+        if ti.map_index >= 0:
+            _populate_task_group_map_index_context(
+                context, ti.dag_id, ti.task_id, ti.map_index, ti.run_id, session, dag_bag
+            )
+
         return context
     except SQLAlchemyError:
         log.exception("Error marking Task Instance state as running")
@@ -921,6 +926,100 @@ def _iter_breadcrumbs() -> Iterator[dict[str, Any]]:
     return TaskBreadcrumbsResponse(breadcrumbs=_iter_breadcrumbs())
 
 
+def _populate_task_group_map_index_context(
+    context: TIRunContext,
+    dag_id: str,
+    task_id: str,
+    map_index: int,
+    run_id: str,
+    session: SessionDep,
+    dag_bag: DagBagDep,
+) -> None:
+    """Populate task group map_index_template and expanded args on the TIRunContext."""
+    try:
+        dag = get_latest_version_of_dag(dag_bag, dag_id, session)
+    except HTTPException:
+        return
+
+    task = dag.task_dict.get(task_id)
+    if not task:
+        return
+
+    # iter_mapped_task_groups walks from innermost to outermost; we use the first match.
+    for mtg in task.iter_mapped_task_groups():
+        if not mtg.map_index_template:
+            continue
+
+        context.task_group_map_index_template = mtg.map_index_template
+        context.task_group_expanded_args = _resolve_task_group_expand_args(
+            mtg._expand_input, map_index, run_id, session
+        )
+        break
+
+
+def _resolve_task_group_expand_args(
+    expand_input: Any,
+    map_index: int,
+    run_id: str,
+    session: SessionDep,
+) -> dict[str, Any] | None:
+    """Resolve the expand_input for a specific map_index to get the expanded arguments."""
+    from airflow.models.expandinput import SchedulerDictOfListsExpandInput, SchedulerListOfDictsExpandInput
+    from airflow.serialization.definitions.xcom_arg import SchedulerXComArg
+
+    def _resolve_at_index(value: Any) -> Any | None:
+        """Resolve a single value (list/tuple or XComArg) at the given map_index."""
+        match value:
+            case SchedulerXComArg():
+                value = _resolve_xcom_arg_value(value, run_id, session)
+            case list() | tuple():
+                pass
+            case _:
+                return None
+        if isinstance(value, (list, tuple)) and map_index < len(value):
+            return value[map_index]
+        return None
+
+    match expand_input:
+        case SchedulerDictOfListsExpandInput(value=mapping):
+            resolved = {}
+            for key, val in mapping.items():
+                if (item := _resolve_at_index(val)) is not None:
+                    resolved[key] = item
+            return resolved or None
+
+        case SchedulerListOfDictsExpandInput(value=val):
+            if isinstance(item := _resolve_at_index(val), dict):
+                return item
+
+    return None
+
+
+def _resolve_xcom_arg_value(xcom_arg: Any, run_id: str, session: SessionDep) -> Any:
+    """Resolve a SchedulerXComArg to its actual value via XCom query."""
+    refs = list(xcom_arg.iter_references())
+    if not refs:
+        return None
+    operator, key = refs[0]
+
+    xcom_value = session.scalar(
+        select(XComModel.value).where(
+            XComModel.dag_id == operator.dag_id,
+            XComModel.task_id == operator.task_id,
+            XComModel.run_id == run_id,
+            XComModel.key == key,
+            XComModel.map_index == -1,
+        )
+    )
+    if xcom_value is None:
+        return None
+    try:
+        return json.loads(xcom_value)
+    except (json.JSONDecodeError, TypeError):
+        log.debug("Failed to decode XCom value for task_group expand args", exc_info=True)
+        return None
+
+
 def _is_eligible_to_retry(state: str, try_number: int, max_tries: int) -> bool:
     """Is task instance is eligible for retry."""
     if state == TaskInstanceState.RESTARTING:
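For intuition, the literal dict-of-lists case handled above reduces to indexing each expansion list at the task instance's map_index. A hand-rolled sketch of that reduction only; it is not the server code path and ignores the XCom-backed arguments that _resolve_xcom_arg_value handles.

    # tg.expand(filename=["a.json", "b.json"]) is stored as a dict of lists.
    expand_value = {"filename": ["a.json", "b.json"]}
    map_index = 1

    # Pick the element for this map_index from every expansion list.
    resolved = {key: vals[map_index] for key, vals in expand_value.items() if map_index < len(vals)}
    assert resolved == {"filename": "b.json"}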
diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py b/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py
index 30d4159f7453c..23a609c3ac4ad 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py
@@ -37,9 +37,11 @@
     ModifyDeferredTaskKwargsToJsonValue,
     RemoveUpstreamMapIndexesField,
 )
+from airflow.api_fastapi.execution_api.versions.v2026_04_15 import AddTaskGroupMapIndexTemplateFields
 
 bundle = VersionBundle(
     HeadVersion(),
+    Version("2026-04-15", AddTaskGroupMapIndexTemplateFields),
     Version("2026-03-31", ModifyDeferredTaskKwargsToJsonValue, RemoveUpstreamMapIndexesField),
     Version("2025-12-08", MovePreviousRunEndpoint, AddDagRunDetailEndpoint),
     Version("2025-11-07", AddPartitionKeyField),
diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_04_15.py b/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_04_15.py
new file mode 100644
index 0000000000000..3ba13efabb0c8
--- /dev/null
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_04_15.py
@@ -0,0 +1,39 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+from cadwyn import ResponseInfo, VersionChange, convert_response_to_previous_version_for, schema
+
+from airflow.api_fastapi.execution_api.datamodels.taskinstance import TIRunContext
+
+
+class AddTaskGroupMapIndexTemplateFields(VersionChange):
+    """Add task_group_map_index_template and task_group_expanded_args fields to TIRunContext."""
+
+    description = __doc__
+
+    instructions_to_migrate_to_previous_version = (
+        schema(TIRunContext).field("task_group_map_index_template").didnt_exist,
+        schema(TIRunContext).field("task_group_expanded_args").didnt_exist,
+    )
+
+    @convert_response_to_previous_version_for(TIRunContext)  # type: ignore[arg-type]
+    def remove_task_group_fields(response: ResponseInfo) -> None:  # type: ignore[misc]
+        """Remove task group map index fields for older API versions."""
+        response.body.pop("task_group_map_index_template", None)
+        response.body.pop("task_group_expanded_args", None)
diff --git a/airflow-core/src/airflow/serialization/definitions/taskgroup.py b/airflow-core/src/airflow/serialization/definitions/taskgroup.py
index d971c303c7c53..31c11e4e02a3f 100644
--- a/airflow-core/src/airflow/serialization/definitions/taskgroup.py
+++ b/airflow-core/src/airflow/serialization/definitions/taskgroup.py
@@ -48,6 +48,7 @@ class SerializedTaskGroup(DAGNode):
     parent_group: SerializedTaskGroup | None = attrs.field()
     dag: SerializedDAG = attrs.field()
     tooltip: str = attrs.field()
+    map_index_template: str | None = attrs.field(default=None)
     default_args: dict[str, Any] = attrs.field(factory=dict)
 
     # TODO: Are these actually useful?
diff --git a/airflow-core/src/airflow/serialization/schema.json b/airflow-core/src/airflow/serialization/schema.json
index 6058275f35c05..9a8df7af9ef56 100644
--- a/airflow-core/src/airflow/serialization/schema.json
+++ b/airflow-core/src/airflow/serialization/schema.json
@@ -371,6 +371,7 @@
       "tooltip": { "type": "string" },
       "ui_color": { "type": "string" },
       "ui_fgcolor": { "type": "string" },
+      "map_index_template": { "type": ["null", "string"] },
       "upstream_group_ids": {
         "type": "array",
         "items": { "type": "string" }
diff --git a/airflow-core/src/airflow/serialization/serialized_objects.py b/airflow-core/src/airflow/serialization/serialized_objects.py
index f1e5813b1893e..8420b3f7a24e9 100644
--- a/airflow-core/src/airflow/serialization/serialized_objects.py
+++ b/airflow-core/src/airflow/serialization/serialized_objects.py
@@ -2088,6 +2088,7 @@ def serialize_task_group(cls, task_group: TaskGroup) -> dict[str, Any] | None:
             "tooltip": task_group.tooltip,
             "ui_color": task_group.ui_color,
             "ui_fgcolor": task_group.ui_fgcolor,
+            "map_index_template": task_group.map_index_template,
             "children": {
                 label: child.serialize_for_task_group() for label, child in task_group.children.items()
             },
@@ -2118,6 +2119,7 @@ def deserialize_task_group(
             for key in ["prefix_group_id", "tooltip", "ui_color", "ui_fgcolor"]
         }
         kwargs["group_display_name"] = cls.deserialize(encoded_group.get("group_display_name", ""))
+        kwargs["map_index_template"] = cls.deserialize(encoded_group.get("map_index_template"))
 
         if not encoded_group.get("is_mapped"):
            group = SerializedTaskGroup(group_id=group_id, parent_group=parent_group, dag=dag, **kwargs)
diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
index 2eba8a045a82b..77ca62a4034e2 100644
--- a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
+++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
@@ -396,6 +396,88 @@ def task_3():
         )
         assert response.status_code == 200
 
+    def test_ti_run_includes_task_group_map_index_template(
+        self, client: Client, dag_maker: DagMaker, session: Session
+    ):
+        """Test that ti_run includes task_group_map_index_template and expanded args for mapped task groups."""
+        with dag_maker("test_tg_map_index_template", serialized=True):
+
+            @task_group(map_index_template="{{ filename }}")
+            def tg(filename):
+                @task
+                def process(filename):
+                    return filename
+
+                process(filename)
+
+            tg.expand(filename=["a.json", "b.json"])
+
+        dr = dag_maker.create_dagrun()
+
+        # Set all task instances to queued for the mapped tasks
+        for ti in dr.get_task_instances():
+            if ti.map_index >= 0:
+                ti.set_state(State.QUEUED)
+        session.flush()
+
+        # Get the first mapped task instance
+        mapped_tis = [ti for ti in dr.get_task_instances() if ti.map_index >= 0]
+        assert len(mapped_tis) > 0
+
+        ti = mapped_tis[0]
+        response = client.patch(
+            f"/execution/task-instances/{ti.id}/run",
+            json={
+                "state": "running",
+                "hostname": "random-hostname",
+                "unixname": "random-unixname",
+                "pid": 100,
+                "start_date": "2024-09-30T12:00:00Z",
+            },
+        )
+
+        assert response.status_code == 200
+        data = response.json()
+
+        assert data.get("task_group_map_index_template") == "{{ filename }}"
+        assert data.get("task_group_expanded_args") == {"filename": "a.json"}
+
+    def test_ti_run_no_task_group_fields_for_non_mapped(
+        self, client, session, create_task_instance, time_machine
+    ):
+        """Test that task_group fields are not set for non-mapped task instances."""
+        instant_str = "2024-09-30T12:00:00Z"
+        instant = timezone.parse(instant_str)
+        time_machine.move_to(instant, tick=False)
+
+        ti = create_task_instance(
+            task_id="test_non_mapped",
+            state=State.QUEUED,
+            dagrun_state=DagRunState.RUNNING,
+            session=session,
+            start_date=instant,
+            dag_id=str(uuid4()),
+        )
+        session.commit()
+
+        response = client.patch(
+            f"/execution/task-instances/{ti.id}/run",
+            json={
+                "state": "running",
+                "hostname": "random-hostname",
+                "unixname": "random-unixname",
+                "pid": 100,
+                "start_date": instant_str,
+            },
+        )
+
+        assert response.status_code == 200
+        data = response.json()
+
+        # Non-mapped tasks should not have task group fields
+        assert "task_group_map_index_template" not in data
+        assert "task_group_expanded_args" not in data
+
     def test_dynamic_task_mapping_with_all_success_trigger_rule(self, dag_maker: DagMaker, session: Session):
         """Test that with ALL_SUCCESS trigger rule and skipped upstream, downstream should not run."""
 
diff --git a/airflow-core/tests/unit/serialization/test_dag_serialization.py b/airflow-core/tests/unit/serialization/test_dag_serialization.py
index 41c74efa2c924..7041d1f2fa9da 100644
--- a/airflow-core/tests/unit/serialization/test_dag_serialization.py
+++ b/airflow-core/tests/unit/serialization/test_dag_serialization.py
@@ -207,6 +207,7 @@ def _operator_defaults(overrides):
         "tooltip": "",
         "ui_color": "CornflowerBlue",
         "ui_fgcolor": "#000",
+        "map_index_template": None,
         "upstream_group_ids": [],
         "downstream_group_ids": [],
         "upstream_task_ids": [],
@@ -1714,6 +1715,30 @@ def check_task_group(node):
 
         check_task_group(serialized_dag.task_group)
 
+    def test_task_group_map_index_template_serialization(self):
+        """Test that map_index_template on TaskGroup survives serialization round-trip."""
+        from airflow.providers.standard.operators.empty import EmptyOperator
+
+        with DAG("test_tg_map_index_template", schedule=None, start_date=datetime(2020, 1, 1)) as dag:
+            with TaskGroup("my_group", map_index_template="{{ filename }}"):
+                EmptyOperator(task_id="task1")
+
+        serialized_dag = DagSerialization.deserialize_dag(DagSerialization.serialize_dag(dag))
+        tg = serialized_dag.task_group.children["my_group"]
+        assert tg.map_index_template == "{{ filename }}"
+
+    def test_task_group_map_index_template_none_by_default(self):
+        """Test that map_index_template defaults to None when not set."""
+        from airflow.providers.standard.operators.empty import EmptyOperator
+
+        with DAG("test_tg_mit_default", schedule=None, start_date=datetime(2020, 1, 1)) as dag:
+            with TaskGroup("my_group"):
+                EmptyOperator(task_id="task1")
+
+        serialized_dag = DagSerialization.deserialize_dag(DagSerialization.serialize_dag(dag))
+        tg = serialized_dag.task_group.children["my_group"]
+        assert tg.map_index_template is None
+
     @staticmethod
     def assert_taskgroup_children(se_task_group, dag_task_group, expected_children):
         assert se_task_group.children.keys() == dag_task_group.children.keys() == expected_children
@@ -3100,6 +3125,7 @@ def tg(a: str) -> None:
             },
             "group_display_name": "",
             "is_mapped": True,
+            "map_index_template": None,
             "prefix_group_id": True,
             "tooltip": "",
             "ui_color": "CornflowerBlue",
@@ -3113,6 +3139,7 @@ def tg(a: str) -> None:
         serde_tg = serde_dag.task_group.children["tg"]
         assert isinstance(serde_tg, SerializedTaskGroup)
         assert serde_tg._expand_input == SchedulerDictOfListsExpandInput({"a": [".", ".."]})
+        assert serde_tg.map_index_template is None
 
 
 @pytest.mark.db_test
diff --git a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py
index 32824a48b219e..c50fac1a49959 100644
--- a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py
+++ b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py
@@ -27,7 +27,7 @@
 
 from pydantic import AwareDatetime, BaseModel, ConfigDict, Field, JsonValue, RootModel
 
-API_VERSION: Final[str] = "2026-03-31"
+API_VERSION: Final[str] = "2026-04-15"
 
 
 class AssetAliasReferenceAssetEventDagRun(BaseModel):
@@ -644,3 +644,5 @@ class TIRunContext(BaseModel):
     next_kwargs: Annotated[dict[str, Any] | str | None, Field(title="Next Kwargs")] = None
     xcom_keys_to_clear: Annotated[list[str] | None, Field(title="Xcom Keys To Clear")] = None
     should_retry: Annotated[bool | None, Field(title="Should Retry")] = False
+    task_group_map_index_template: Annotated[str | None, Field(title="Task Group Map Index Template")] = None
+    task_group_expanded_args: Annotated[dict[str, Any] | None, Field(title="Task Group Expanded Args")] = None
diff --git a/task-sdk/src/airflow/sdk/definitions/decorators/task_group.py b/task-sdk/src/airflow/sdk/definitions/decorators/task_group.py
index 44baff52c1d27..34cb29fb87792 100644
--- a/task-sdk/src/airflow/sdk/definitions/decorators/task_group.py
+++ b/task-sdk/src/airflow/sdk/definitions/decorators/task_group.py
@@ -196,6 +196,7 @@ def task_group(
     tooltip: str = "",
     ui_color: str = "CornflowerBlue",
     ui_fgcolor: str = "#000",
+    map_index_template: str | None = None,
     add_suffix_on_collision: bool = False,
     group_display_name: str = "",
 ) -> Callable[[Callable[FParams, FReturn]], _TaskGroupFactory[FParams, FReturn]]: ...
diff --git a/task-sdk/src/airflow/sdk/definitions/taskgroup.py b/task-sdk/src/airflow/sdk/definitions/taskgroup.py
index c47b1f360aea6..4889627c3e2f9 100644
--- a/task-sdk/src/airflow/sdk/definitions/taskgroup.py
+++ b/task-sdk/src/airflow/sdk/definitions/taskgroup.py
@@ -105,6 +105,10 @@ class TaskGroup(DAGNode):
     :param add_suffix_on_collision: If this task group name already exists, automatically add `__1` etc suffixes
     :param group_display_name: If set, this will be the display name for the TaskGroup node in the UI.
+    :param map_index_template: A Jinja2 template string used to render a meaningful label for each
+        mapped instance of this task group. The template has access to the group's expanded arguments.
+        For example, ``"{{ filename }}"`` when the group is expanded with ``filename=["a.json", "b.json"]``
+        will label the mapped instances as ``"a.json"`` and ``"b.json"`` for ALL child tasks.
     """
 
     _group_id: str | None = attrs.field(
@@ -134,6 +138,8 @@ class TaskGroup(DAGNode):
     ui_color: str = attrs.field(default="CornflowerBlue", validator=attrs.validators.instance_of(str))
     ui_fgcolor: str = attrs.field(default="#000", validator=attrs.validators.instance_of(str))
 
+    map_index_template: str | None = attrs.field(default=None)
+
     add_suffix_on_collision: bool = False
 
     @dag.validator
diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
index 2b3933c107115..57bc3dfa7acb5 100644
--- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -272,6 +272,17 @@ def get_template_context(self) -> Context:
         if upstream_map_indexes is not None:
             setattr(self, "_upstream_map_indexes", upstream_map_indexes)
 
+        # Kept separate from the main context to avoid conflicts with op_kwargs
+        # in PythonOperator.determine_kwargs — only used during _render_map_index.
+        task_group_expanded_args = getattr(from_server, "task_group_expanded_args", None)
+        if task_group_expanded_args:
+            self._task_group_expanded_args = task_group_expanded_args
+
+        if not self.task.map_index_template:
+            task_group_template = getattr(from_server, "task_group_map_index_template", None)
+            if task_group_template:
+                self._cached_template_context["map_index_template"] = task_group_template
+
         return self._cached_template_context
 
     def render_templates(
@@ -1618,7 +1629,11 @@ def _render_map_index(context: Context, ti: RuntimeTaskInstance, log: Logger) ->
         return None
     log.debug("Rendering map_index_template", template_length=len(template))
     jinja_env = ti.task.dag.get_template_env()
-    rendered_map_index = jinja_env.from_string(template).render(context)
+    render_context: dict[str, Any] = dict(context)
+    task_group_args = getattr(ti, "_task_group_expanded_args", None)
+    if task_group_args:
+        render_context.update(task_group_args)
+    rendered_map_index = jinja_env.from_string(template).render(render_context)
     log.debug("Map index rendered", length=len(rendered_map_index))
     return rendered_map_index
 
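A condensed view of the rendering path added above, as a standalone sketch: the group's expanded args are merged into a copy of the context for this one render call only, so they never leak into the task's real execution context or op_kwargs. It uses a plain jinja2 Environment rather than the DAG's template environment, purely for illustration.

    import jinja2


    def render_map_index_sketch(template: str, context: dict, task_group_args: dict | None) -> str:
        # Copy the context; the real template context stays untouched.
        render_context = dict(context)
        if task_group_args:
            render_context.update(task_group_args)
        return jinja2.Environment().from_string(template).render(render_context)


    assert render_map_index_sketch("{{ filename }}", {}, {"filename": "data1.json"}) == "data1.json"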
diff --git a/task-sdk/tests/task_sdk/definitions/decorators/test_task_group.py b/task-sdk/tests/task_sdk/definitions/decorators/test_task_group.py
index 26e06cbfd6db2..4561a949a6f8b 100644
--- a/task-sdk/tests/task_sdk/definitions/decorators/test_task_group.py
+++ b/task-sdk/tests/task_sdk/definitions/decorators/test_task_group.py
@@ -327,3 +327,38 @@ def tg():
 
     p = pipeline()
     assert p.task_group_dict["tg"].label == "my_custom_name"
+
+
+def test_task_group_map_index_template_attribute():
+    """Test that map_index_template is stored on the TaskGroup when set via the decorator."""
+
+    @dag(schedule=None, start_date=pendulum.datetime(2022, 1, 1))
+    def pipeline():
+        @task_group(map_index_template="{{ filename }}")
+        def tg():
+            pass
+
+        tg()
+
+    p = pipeline()
+
+    assert p.task_group_dict["tg"].map_index_template == "{{ filename }}"
+
+
+def test_task_group_map_index_template_with_expand():
+    """Test that map_index_template works with expand(), creating a MappedTaskGroup that carries the template."""
+
+    @dag(schedule=None, start_date=pendulum.datetime(2022, 1, 1))
+    def pipeline():
+        @task_group(map_index_template="{{ filename }}")
+        def tg(filename):
+            pass
+
+        tg.expand(filename=["a.json", "b.json"])
+
+    d = pipeline()
+
+    tg = d.task_group_dict["tg"]
+    assert isinstance(tg, MappedTaskGroup)
+    assert tg.map_index_template == "{{ filename }}"
+    assert tg._expand_input == DictOfListsExpandInput({"filename": ["a.json", "b.json"]})
diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
index 8724877853f2a..0e0347d9f7b87 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
@@ -1466,6 +1466,58 @@ def test_function(ti):
     assert ti.rendered_map_index == "Label: test_task"
 
 
+def test_task_group_map_index_template_injected_into_context(create_runtime_ti, mock_supervisor_comms):
+    """Test that task_group_map_index_template from server is used when task has no map_index_template."""
+
+    task = BaseOperator(task_id="test_task")
+
+    ti = create_runtime_ti(task=task, dag_id="dag_with_tg_map_index_template")
+
+    ti._ti_context_from_server.task_group_map_index_template = "{{ filename }}"
+    ti._ti_context_from_server.task_group_expanded_args = {"filename": "data1.json"}
+    ti._cached_template_context = None
+
+    context = ti.get_template_context()
+
+    # Expanded args are kept out of the main context to avoid conflicts with op_kwargs
+    assert "filename" not in context
+    assert ti._task_group_expanded_args == {"filename": "data1.json"}
+    assert context["map_index_template"] == "{{ filename }}"
+
+
+def test_task_level_map_index_template_takes_precedence(create_runtime_ti, mock_supervisor_comms):
+    """Test that task-level map_index_template takes precedence over group-level."""
+
+    task = BaseOperator(task_id="test_task", map_index_template="Task: {{ task.task_id }}")
+
+    ti = create_runtime_ti(task=task, dag_id="dag_with_both_templates")
+
+    ti._ti_context_from_server.task_group_map_index_template = "{{ filename }}"
+    ti._ti_context_from_server.task_group_expanded_args = {"filename": "data1.json"}
+    ti._cached_template_context = None
+
+    context = ti.get_template_context()
+
+    assert ti._task_group_expanded_args == {"filename": "data1.json"}
+    assert context["map_index_template"] == "Task: {{ task.task_id }}"
+
+
+def test_task_group_expanded_args_used_in_map_index_rendering(create_runtime_ti, mock_supervisor_comms):
+    """Test that group expanded args are available during map_index_template rendering."""
+
+    task = BaseOperator(task_id="test_task")
+
+    ti = create_runtime_ti(task=task, dag_id="dag_with_tg_rendering")
+
+    ti._ti_context_from_server.task_group_map_index_template = "{{ filename }}"
+    ti._ti_context_from_server.task_group_expanded_args = {"filename": "data1.json"}
+    ti._cached_template_context = None
+
+    run(ti, ti.get_template_context(), log=mock.MagicMock())
+
+    assert ti.rendered_map_index == "data1.json"
+
+
 class TestRuntimeTaskInstance:
     def test_get_context_without_ti_context_from_server(self, mocked_parse, make_ti_context):
         """Test get_template_context without ti_context_from_server."""