Skip to content

Commit 578ad78

Browse files
authored
Fix Task Mapping with XCOM arguments from other Tasks (apache#47141)
1 parent 7181cfa commit 578ad78

File tree

7 files changed

+119
-10
lines changed

7 files changed

+119
-10
lines changed

task_sdk/src/airflow/sdk/api/client.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -298,14 +298,23 @@ def get(
298298
return XComResponse.model_validate_json(resp.read())
299299

300300
def set(
301-
self, dag_id: str, run_id: str, task_id: str, key: str, value, map_index: int | None = None
301+
self,
302+
dag_id: str,
303+
run_id: str,
304+
task_id: str,
305+
key: str,
306+
value,
307+
map_index: int | None = None,
308+
mapped_length: int | None = None,
302309
) -> dict[str, bool]:
303310
"""Set a XCom value via the API server."""
304311
# TODO: check if we need to use map_index as params in the uri
305312
# ref: https://github.com/apache/airflow/blob/v2-10-stable/airflow/api_connexion/openapi/v1.yaml#L1785C1-L1785C81
306313
params = {}
307314
if map_index is not None and map_index >= 0:
308315
params = {"map_index": map_index}
316+
if mapped_length is not None and mapped_length >= 0:
317+
params["mapped_length"] = mapped_length
309318
self.client.post(f"xcoms/{dag_id}/{run_id}/{task_id}/{key}", params=params, json=value)
310319
# Any error from the server will anyway be propagated down to the supervisor,
311320
# so we choose to send a generic response to the supervisor over the server response to

task_sdk/src/airflow/sdk/definitions/xcom_arg.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -338,12 +338,23 @@ def resolve(self, context: Mapping[str, Any]) -> Any:
338338

339339
if self.operator.is_mapped:
340340
return LazyXComSequence[Any](xcom_arg=self, ti=ti)
341-
342-
result = ti.xcom_pull(
343-
task_ids=task_id,
344-
key=self.key,
345-
default=NOTSET,
346-
)
341+
tg = ti.task.get_closest_mapped_task_group()
342+
result = None
343+
if tg is None:
344+
# regular task
345+
result = ti.xcom_pull(
346+
task_ids=task_id,
347+
key=self.key,
348+
default=NOTSET,
349+
map_indexes=None,
350+
)
351+
else:
352+
# task from a task group
353+
result = ti.xcom_pull(
354+
task_ids=task_id,
355+
key=self.key,
356+
default=NOTSET,
357+
)
347358
if not isinstance(result, ArgNotSet):
348359
return result
349360
if self.key == XCOM_RETURN_KEY:

task_sdk/src/airflow/sdk/execution_time/supervisor.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -827,7 +827,9 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger):
827827
self._terminal_state = IntermediateTIState.UP_FOR_RESCHEDULE
828828
self.client.task_instances.reschedule(self.id, msg)
829829
elif isinstance(msg, SetXCom):
830-
self.client.xcoms.set(msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.value, msg.map_index)
830+
self.client.xcoms.set(
831+
msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.value, msg.map_index, msg.mapped_length
832+
)
831833
elif isinstance(msg, PutVariable):
832834
self.client.variables.set(msg.key, msg.value, msg.description)
833835
elif isinstance(msg, SetRenderedFields):

task_sdk/tests/api/test_client.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -596,6 +596,33 @@ def handle_request(request: httpx.Request) -> httpx.Response:
596596
)
597597
assert result == {"ok": True}
598598

599+
def test_xcom_set_with_mapped_length(self):
600+
# Simulate a successful response from the server when setting an xcom with mapped_length
601+
def handle_request(request: httpx.Request) -> httpx.Response:
602+
if (
603+
request.url.path == "/xcoms/dag_id/run_id/task_id/key"
604+
and request.url.params.get("map_index") == "2"
605+
and request.url.params.get("mapped_length") == "3"
606+
):
607+
assert json.loads(request.read()) == "value1"
608+
return httpx.Response(
609+
status_code=201,
610+
json={"message": "XCom successfully set"},
611+
)
612+
return httpx.Response(status_code=400, json={"detail": "Bad Request"})
613+
614+
client = make_client(transport=httpx.MockTransport(handle_request))
615+
result = client.xcoms.set(
616+
dag_id="dag_id",
617+
run_id="run_id",
618+
task_id="task_id",
619+
key="key",
620+
value="value1",
621+
map_index=2,
622+
mapped_length=3,
623+
)
624+
assert result == {"ok": True}
625+
599626

600627
class TestConnectionOperations:
601628
"""

task_sdk/tests/definitions/test_mappedoperator.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -570,10 +570,16 @@ def tg(va):
570570
return t2
571571

572572
# The group is mapped by 3.
573-
t2 = tg.expand(va=[["a", "b"], [4], ["z"]])
573+
tg1 = tg.expand(
574+
va=[
575+
["a", "b"],
576+
[4],
577+
["z"],
578+
]
579+
)
574580

575581
# Aggregates results from task group.
576-
t.override(task_id="t3")(t2)
582+
t.override(task_id="t3")(tg1)
577583

578584
def xcom_get():
579585
# TODO: Tidy this after #45927 is reopened and fixed properly

task_sdk/tests/execution_time/test_supervisor.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1042,6 +1042,7 @@ def watched_subprocess(self, mocker):
10421042
"test_key",
10431043
'{"key": "test_key", "value": {"key2": "value2"}}',
10441044
None,
1045+
None,
10451046
),
10461047
{},
10471048
{"ok": True},
@@ -1065,11 +1066,37 @@ def watched_subprocess(self, mocker):
10651066
"test_key",
10661067
'{"key": "test_key", "value": {"key2": "value2"}}',
10671068
2,
1069+
None,
10681070
),
10691071
{},
10701072
{"ok": True},
10711073
id="set_xcom_with_map_index",
10721074
),
1075+
pytest.param(
1076+
SetXCom(
1077+
dag_id="test_dag",
1078+
run_id="test_run",
1079+
task_id="test_task",
1080+
key="test_key",
1081+
value='{"key": "test_key", "value": {"key2": "value2"}}',
1082+
map_index=2,
1083+
mapped_length=3,
1084+
),
1085+
b"",
1086+
"xcoms.set",
1087+
(
1088+
"test_dag",
1089+
"test_run",
1090+
"test_task",
1091+
"test_key",
1092+
'{"key": "test_key", "value": {"key2": "value2"}}',
1093+
2,
1094+
3,
1095+
),
1096+
{},
1097+
{"ok": True},
1098+
id="set_xcom_with_map_index_and_mapped_length",
1099+
),
10731100
# we aren't adding all states under TerminalTIState here, because this test's scope is only to check
10741101
# if it can handle TaskState message
10751102
pytest.param(

tests/api_fastapi/execution_api/routes/test_xcoms.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,33 @@ def test_xcom_set(self, client, create_task_instance, session, value, expected_v
120120
task_map = session.query(TaskMap).filter_by(task_id=ti.task_id, dag_id=ti.dag_id).one_or_none()
121121
assert task_map is None, "Should not be mapped"
122122

123+
def test_xcom_set_mapped(self, client, create_task_instance, session):
124+
ti = create_task_instance()
125+
session.commit()
126+
127+
response = client.post(
128+
f"/execution/xcoms/{ti.dag_id}/{ti.run_id}/{ti.task_id}/xcom_1",
129+
params={"map_index": -1, "mapped_length": 3},
130+
json="value1",
131+
)
132+
133+
assert response.status_code == 201
134+
assert response.json() == {"message": "XCom successfully set"}
135+
136+
xcom = (
137+
session.query(XCom)
138+
.filter_by(task_id=ti.task_id, dag_id=ti.dag_id, key="xcom_1", map_index=-1)
139+
.first()
140+
)
141+
assert xcom.value == "value1"
142+
task_map = session.query(TaskMap).filter_by(task_id=ti.task_id, dag_id=ti.dag_id).one_or_none()
143+
assert task_map is not None, "Should be mapped"
144+
assert task_map.dag_id == "dag"
145+
assert task_map.run_id == "test"
146+
assert task_map.task_id == "op1"
147+
assert task_map.map_index == -1
148+
assert task_map.length == 3
149+
123150
@pytest.mark.parametrize(
124151
("length", "err_context"),
125152
[

0 commit comments

Comments
 (0)