-
Notifications
You must be signed in to change notification settings - Fork 16.5k
Add map_index_template support for mapped task group (#40799) #61975
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -264,6 +264,11 @@ def ti_run( | |
| context.next_method = ti.next_method | ||
| context.next_kwargs = ti.next_kwargs | ||
|
|
||
| if ti.map_index >= 0: | ||
| _populate_task_group_map_index_context( | ||
| context, ti.dag_id, ti.task_id, ti.map_index, ti.run_id, session, dag_bag | ||
| ) | ||
|
|
||
| return context | ||
| except SQLAlchemyError: | ||
| log.exception("Error marking Task Instance state as running") | ||
|
|
@@ -921,6 +926,100 @@ def _iter_breadcrumbs() -> Iterator[dict[str, Any]]: | |
| return TaskBreadcrumbsResponse(breadcrumbs=_iter_breadcrumbs()) | ||
|
|
||
|
|
||
def _populate_task_group_map_index_context(
    context: TIRunContext,
    dag_id: str,
    task_id: str,
    map_index: int,
    run_id: str,
    session: SessionDep,
    dag_bag: DagBagDep,
) -> None:
    """
    Fill in the task-group map-index fields of ``context`` for a mapped task instance.

    Looks up the latest DAG version, finds the task, and picks the first enclosing
    mapped task group that declares a ``map_index_template``. When one is found,
    ``context.task_group_map_index_template`` and ``context.task_group_expanded_args``
    are set; otherwise ``context`` is left untouched.

    :param context: Run context that is mutated in place.
    :param dag_id: ID of the DAG owning the task.
    :param task_id: ID of the task whose enclosing groups are inspected.
    :param map_index: Map index of the task instance being started.
    :param run_id: DAG run used when resolving XCom-backed expand inputs.
    :param session: Database session.
    :param dag_bag: DAG bag used to load the DAG.
    """
    try:
        dag = get_latest_version_of_dag(dag_bag, dag_id, session)
    except HTTPException:
        # Best effort: a DAG that cannot be loaded just means no extra context.
        return

    task = dag.task_dict.get(task_id)
    if not task:
        return

    # The iterator is expected to yield groups from innermost to outermost, so the
    # first group carrying a template is the closest one to the task.
    templated_group = next(
        (group for group in task.iter_mapped_task_groups() if group.map_index_template),
        None,
    )
    if templated_group is None:
        return

    context.task_group_map_index_template = templated_group.map_index_template
    context.task_group_expanded_args = _resolve_task_group_expand_args(
        templated_group._expand_input, map_index, run_id, session
    )
|
|
||
|
|
||
| def _resolve_task_group_expand_args( | ||
| expand_input: Any, | ||
| map_index: int, | ||
| run_id: str, | ||
| session: SessionDep, | ||
| ) -> dict[str, Any] | None: | ||
| """Resolve the expand_input for a specific map_index to get the expanded arguments.""" | ||
| from airflow.models.expandinput import SchedulerDictOfListsExpandInput, SchedulerListOfDictsExpandInput | ||
| from airflow.serialization.definitions.xcom_arg import SchedulerXComArg | ||
|
|
||
| def _resolve_at_index(value: Any) -> Any | None: | ||
| """Resolve a single value (list/tuple or XComArg) at the given map_index.""" | ||
| match value: | ||
| case SchedulerXComArg(): | ||
| value = _resolve_xcom_arg_value(value, run_id, session) | ||
| case list() | tuple(): | ||
| pass | ||
| case _: | ||
| return None | ||
| if isinstance(value, (list, tuple)) and map_index < len(value): | ||
| return value[map_index] | ||
| return None | ||
|
|
||
| match expand_input: | ||
| case SchedulerDictOfListsExpandInput(value=mapping): | ||
| resolved = {} | ||
| for key, val in mapping.items(): | ||
| if (item := _resolve_at_index(val)) is not None: | ||
| resolved[key] = item | ||
| return resolved or None | ||
|
|
||
| case SchedulerListOfDictsExpandInput(value=val): | ||
| if isinstance(item := _resolve_at_index(val), dict): | ||
| return item | ||
|
|
||
| return None | ||
anishgirianish marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
|
|
||
| def _resolve_xcom_arg_value(xcom_arg: Any, run_id: str, session: SessionDep) -> Any: | ||
| """Resolve a SchedulerXComArg to its actual value via XCom query.""" | ||
| refs = list(xcom_arg.iter_references()) | ||
| if not refs: | ||
| return None | ||
| operator, key = refs[0] | ||
|
|
||
| xcom_value = session.scalar( | ||
| select(XComModel.value).where( | ||
| XComModel.dag_id == operator.dag_id, | ||
| XComModel.task_id == operator.task_id, | ||
| XComModel.run_id == run_id, | ||
| XComModel.key == key, | ||
| XComModel.map_index == -1, | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Does this mean the group's expand input can't depend on mapped upstream task outputs? If so, this should be documented as a limitation.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Great question! From what I can tell, the upstream task producing the expand input is unmapped; it returns a list that drives the mapping. If it were mapped, it would go through expand_kwargs instead, so it wouldn't be a SchedulerXComArg here. But I'd love to hear if you've seen a case where this breaks.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I'm not that familiar with XCom intricacies - perhaps it's worth adding a test that ensures this scenario works correctly. |
||
| ) | ||
| ) | ||
| if xcom_value is None: | ||
| return None | ||
| try: | ||
| return json.loads(xcom_value) | ||
| except (json.JSONDecodeError, TypeError): | ||
| log.debug("Failed to decode XCom value for task_group expand args", exc_info=True) | ||
| return None | ||
anishgirianish marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
|
|
||
| def _is_eligible_to_retry(state: str, try_number: int, max_tries: int) -> bool: | ||
| """Is task instance is eligible for retry.""" | ||
| if state == TaskInstanceState.RESTARTING: | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,39 @@ | ||
| # Licensed to the Apache Software Foundation (ASF) under one | ||
| # or more contributor license agreements. See the NOTICE file | ||
| # distributed with this work for additional information | ||
| # regarding copyright ownership. The ASF licenses this file | ||
| # to you under the Apache License, Version 2.0 (the | ||
| # "License"); you may not use this file except in compliance | ||
| # with the License. You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, | ||
| # software distributed under the License is distributed on an | ||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
| # KIND, either express or implied. See the License for the | ||
| # specific language governing permissions and limitations | ||
| # under the License. | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| from cadwyn import ResponseInfo, VersionChange, convert_response_to_previous_version_for, schema | ||
|
|
||
| from airflow.api_fastapi.execution_api.datamodels.taskinstance import TIRunContext | ||
|
|
||
|
|
||
class AddTaskGroupMapIndexTemplateFields(VersionChange):
    """Add task_group_map_index_template and task_group_expanded_args fields to TIRunContext."""

    # The class docstring above is reused verbatim as the change description,
    # so its wording is runtime data, not just documentation.
    description = __doc__

    # Schema migration: in the previous API version neither field existed on TIRunContext.
    instructions_to_migrate_to_previous_version = (
        schema(TIRunContext).field("task_group_map_index_template").didnt_exist,
        schema(TIRunContext).field("task_group_expanded_args").didnt_exist,
    )

    @convert_response_to_previous_version_for(TIRunContext)  # type: ignore[arg-type]
    def remove_task_group_fields(response: ResponseInfo) -> None:  # type: ignore[misc]
        """Drop the task-group map-index fields from the response body for older API versions."""
        for field_name in ("task_group_map_index_template", "task_group_expanded_args"):
            # A default of None keeps this a no-op when the field is absent.
            response.body.pop(field_name, None)
Uh oh!
There was an error while loading. Please reload this page.