Skip to content

Commit 2ea7aed

Browse files
authored
Fix xcom_pull for task_ids=None (apache#47407)
1 parent b3bccab commit 2ea7aed

File tree

2 files changed

+15
-4
lines changed

2 files changed

+15
-4
lines changed

task-sdk/src/airflow/sdk/execution_time/task_runner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,7 @@ def xcom_pull(
292292

293293
if task_ids is None:
294294
# default to the current task if not provided
295-
task_ids = self.task_id
295+
task_ids = [self.task_id]
296296
elif isinstance(task_ids, str):
297297
task_ids = [task_ids]
298298
if isinstance(map_indexes, ArgNotSet):

task-sdk/tests/task_sdk/execution_time/test_task_runner.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import json
2222
import os
2323
import uuid
24+
from collections.abc import Iterable
2425
from datetime import datetime, timedelta
2526
from pathlib import Path
2627
from socket import socketpair
@@ -93,6 +94,7 @@
9394
)
9495
from airflow.utils import timezone
9596
from airflow.utils.state import TaskInstanceState
97+
from airflow.utils.types import NOTSET, ArgNotSet
9698

9799
from tests_common.test_utils.mock_operators import AirflowLink
98100

@@ -1134,28 +1136,37 @@ def test_get_variable_from_context(
11341136
"push_task",
11351137
["push_task1", "push_task2"],
11361138
{"push_task1", "push_task2"},
1139+
None,
1140+
NOTSET,
11371141
],
11381142
)
11391143
def test_xcom_pull(self, create_runtime_ti, mock_supervisor_comms, spy_agency, task_ids):
11401144
"""Test that a task pulls the expected XCom value if it exists."""
11411145

11421146
class CustomOperator(BaseOperator):
11431147
def execute(self, context):
1144-
value = context["ti"].xcom_pull(task_ids=task_ids, key="key")
1148+
if isinstance(task_ids, ArgNotSet):
1149+
value = context["ti"].xcom_pull(key="key")
1150+
else:
1151+
value = context["ti"].xcom_pull(task_ids=task_ids, key="key")
11451152
print(f"Pulled XCom Value: {value}")
11461153

1147-
task = CustomOperator(task_id="pull_task")
1154+
test_task_id = "pull_task"
1155+
task = CustomOperator(task_id=test_task_id)
11481156

11491157
runtime_ti = create_runtime_ti(task=task)
11501158

11511159
mock_supervisor_comms.get_message.return_value = XComResult(key="key", value='"value"')
11521160

11531161
run(runtime_ti, log=mock.MagicMock())
11541162

1155-
if isinstance(task_ids, str):
1163+
if not isinstance(task_ids, Iterable) or isinstance(task_ids, str):
11561164
task_ids = [task_ids]
11571165

11581166
for task_id in task_ids:
1167+
# Without task_ids (or None) expected behavior is to pull with calling task_id
1168+
if task_id is None or isinstance(task_id, ArgNotSet):
1169+
task_id = test_task_id
11591170
mock_supervisor_comms.send_request.assert_any_call(
11601171
log=mock.ANY,
11611172
msg=GetXCom(

0 commit comments

Comments
 (0)