|
21 | 21 | import json |
22 | 22 | import os |
23 | 23 | import uuid |
| 24 | +from collections.abc import Iterable |
24 | 25 | from datetime import datetime, timedelta |
25 | 26 | from pathlib import Path |
26 | 27 | from socket import socketpair |
|
93 | 94 | ) |
94 | 95 | from airflow.utils import timezone |
95 | 96 | from airflow.utils.state import TaskInstanceState |
| 97 | +from airflow.utils.types import NOTSET, ArgNotSet |
96 | 98 |
|
97 | 99 | from tests_common.test_utils.mock_operators import AirflowLink |
98 | 100 |
|
@@ -1134,28 +1136,37 @@ def test_get_variable_from_context( |
1134 | 1136 | "push_task", |
1135 | 1137 | ["push_task1", "push_task2"], |
1136 | 1138 | {"push_task1", "push_task2"}, |
| 1139 | + None, |
| 1140 | + NOTSET, |
1137 | 1141 | ], |
1138 | 1142 | ) |
1139 | 1143 | def test_xcom_pull(self, create_runtime_ti, mock_supervisor_comms, spy_agency, task_ids): |
1140 | 1144 | """Test that a task pulls the expected XCom value if it exists.""" |
1141 | 1145 |
|
1142 | 1146 | class CustomOperator(BaseOperator): |
1143 | 1147 | def execute(self, context): |
1144 | | - value = context["ti"].xcom_pull(task_ids=task_ids, key="key") |
| 1148 | + if isinstance(task_ids, ArgNotSet): |
| 1149 | + value = context["ti"].xcom_pull(key="key") |
| 1150 | + else: |
| 1151 | + value = context["ti"].xcom_pull(task_ids=task_ids, key="key") |
1145 | 1152 | print(f"Pulled XCom Value: {value}") |
1146 | 1153 |
|
1147 | | - task = CustomOperator(task_id="pull_task") |
| 1154 | + test_task_id = "pull_task" |
| 1155 | + task = CustomOperator(task_id=test_task_id) |
1148 | 1156 |
|
1149 | 1157 | runtime_ti = create_runtime_ti(task=task) |
1150 | 1158 |
|
1151 | 1159 | mock_supervisor_comms.get_message.return_value = XComResult(key="key", value='"value"') |
1152 | 1160 |
|
1153 | 1161 | run(runtime_ti, log=mock.MagicMock()) |
1154 | 1162 |
|
1155 | | - if isinstance(task_ids, str): |
| 1163 | + if not isinstance(task_ids, Iterable) or isinstance(task_ids, str): |
1156 | 1164 | task_ids = [task_ids] |
1157 | 1165 |
|
1158 | 1166 | for task_id in task_ids: |
| 1167 | + # Without task_ids (or None) expected behavior is to pull with calling task_id |
| 1168 | + if task_id is None or isinstance(task_id, ArgNotSet): |
| 1169 | + task_id = test_task_id |
1159 | 1170 | mock_supervisor_comms.send_request.assert_any_call( |
1160 | 1171 | log=mock.ANY, |
1161 | 1172 | msg=GetXCom( |
|
0 commit comments