Skip to content

Commit b7cb84e

Browse files
Fix get_ti_count and get_task_states access in callbackrequests (apache#56822)
* Fix get_ti_count and get_task_states access in callbackrequests * Add tests
1 parent da66c41 commit b7cb84e

File tree

2 files changed

+126
-0
lines changed

2 files changed

+126
-0
lines changed

airflow-core/src/airflow/dag_processing/processor.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@
4444
GetConnection,
4545
GetPreviousDagRun,
4646
GetPrevSuccessfulDagRun,
47+
GetTaskStates,
48+
GetTICount,
4749
GetVariable,
4850
GetXCom,
4951
GetXComCount,
@@ -54,6 +56,7 @@
5456
PreviousDagRunResult,
5557
PrevSuccessfulDagRunResult,
5658
PutVariable,
59+
TaskStatesResult,
5760
VariableResult,
5861
XComCountResponse,
5962
XComResult,
@@ -116,6 +119,8 @@ class DagFileParsingResult(BaseModel):
116119
| GetConnection
117120
| GetVariable
118121
| PutVariable
122+
| GetTaskStates
123+
| GetTICount
119124
| DeleteVariable
120125
| GetPrevSuccessfulDagRun
121126
| GetPreviousDagRun
@@ -131,6 +136,7 @@ class DagFileParsingResult(BaseModel):
131136
DagFileParseRequest
132137
| ConnectionResult
133138
| VariableResult
139+
| TaskStatesResult
134140
| PreviousDagRunResult
135141
| PrevSuccessfulDagRunResult
136142
| ErrorResponse
@@ -517,6 +523,7 @@ def _on_child_started(
517523
def _handle_request(self, msg: ToManager, log: FilteringBoundLogger, req_id: int) -> None:
518524
from airflow.sdk.api.datamodels._generated import (
519525
ConnectionResponse,
526+
TaskStatesResponse,
520527
VariableResponse,
521528
XComSequenceIndexResponse,
522529
)
@@ -589,6 +596,29 @@ def _handle_request(self, msg: ToManager, log: FilteringBoundLogger, req_id: int
589596
from airflow.sdk.log import mask_secret
590597

591598
mask_secret(msg.value, msg.name)
599+
elif isinstance(msg, GetTICount):
600+
resp = self.client.task_instances.get_count(
601+
dag_id=msg.dag_id,
602+
map_index=msg.map_index,
603+
task_ids=msg.task_ids,
604+
task_group_id=msg.task_group_id,
605+
logical_dates=msg.logical_dates,
606+
run_ids=msg.run_ids,
607+
states=msg.states,
608+
)
609+
elif isinstance(msg, GetTaskStates):
610+
task_states_map = self.client.task_instances.get_task_states(
611+
dag_id=msg.dag_id,
612+
map_index=msg.map_index,
613+
task_ids=msg.task_ids,
614+
task_group_id=msg.task_group_id,
615+
logical_dates=msg.logical_dates,
616+
run_ids=msg.run_ids,
617+
)
618+
if isinstance(task_states_map, TaskStatesResponse):
619+
resp = TaskStatesResult.from_api_response(task_states_map)
620+
else:
621+
resp = task_states_map
592622
else:
593623
log.error("Unhandled request", msg=msg)
594624
self.send_msg(

airflow-core/tests/unit/dag_processing/test_processor.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,12 @@
6868
from airflow.sdk.api.datamodels._generated import DagRunState
6969
from airflow.sdk.execution_time import comms
7070
from airflow.sdk.execution_time.comms import (
71+
GetTaskStates,
72+
GetTICount,
7173
GetXCom,
7274
GetXComSequenceSlice,
75+
TaskStatesResult,
76+
TICount,
7377
ToSupervisor,
7478
ToTask,
7579
XComResult,
@@ -1046,6 +1050,98 @@ def fake_collect_dags(self, *args, **kwargs):
10461050

10471051
mock_supervisor_comms.send.assert_called_once_with(msg=expected_message)
10481052

1053+
@pytest.mark.parametrize(
1054+
"request_operation,operation_type,mock_response,operation_response",
1055+
[
1056+
(
1057+
lambda context: context["task_instance"].get_ti_count(dag_id="test_dag"),
1058+
GetTICount(dag_id="test_dag"),
1059+
TICount(count=2),
1060+
"Got response 2",
1061+
),
1062+
(
1063+
lambda context: context["task_instance"].get_task_states(
1064+
dag_id="test_dag", task_ids=["test_task"]
1065+
),
1066+
GetTaskStates(
1067+
dag_id="test_dag",
1068+
task_ids=["test_task"],
1069+
),
1070+
TaskStatesResult(task_states={"test_run": {"task1": "running"}}),
1071+
"Got response {'test_run': {'task1': 'running'}}",
1072+
),
1073+
],
1074+
)
1075+
def test_dagfileprocessorprocess_request_handler_operations(
1076+
self,
1077+
spy_agency,
1078+
mock_supervisor_comms,
1079+
request_operation,
1080+
operation_type,
1081+
mock_response,
1082+
operation_response,
1083+
caplog,
1084+
):
1085+
"""Test that DagFileProcessorProcess Request Handler Operations"""
1086+
1087+
mock_supervisor_comms.send.return_value = mock_response
1088+
1089+
def callback_fn(context):
1090+
log = structlog.get_logger()
1091+
log.info("Callback started..")
1092+
log.info("Got response %s", request_operation(context))
1093+
1094+
with DAG(dag_id="test_dag", on_success_callback=callback_fn) as dag:
1095+
BaseOperator(task_id="test_task")
1096+
1097+
def fake_collect_dags(self, *args, **kwargs):
1098+
self.dags[dag.dag_id] = dag
1099+
1100+
spy_agency.spy_on(DagBag.collect_dags, call_fake=fake_collect_dags, owner=DagBag)
1101+
1102+
dagbag = DagBag()
1103+
dagbag.collect_dags()
1104+
1105+
current_time = timezone.utcnow()
1106+
request = DagCallbackRequest(
1107+
filepath="test.py",
1108+
dag_id="test_dag",
1109+
run_id="test_run",
1110+
bundle_name="testing",
1111+
bundle_version=None,
1112+
context_from_server=DagRunContext(
1113+
dag_run=DRDataModel(
1114+
dag_id="test_dag",
1115+
run_id="test_run",
1116+
logical_date=current_time,
1117+
data_interval_start=current_time,
1118+
data_interval_end=current_time,
1119+
run_after=current_time,
1120+
start_date=current_time,
1121+
end_date=None,
1122+
run_type="manual",
1123+
state="success",
1124+
consumed_asset_events=[],
1125+
),
1126+
last_ti=TIDataModel(
1127+
id=uuid.uuid4(),
1128+
dag_id="test_dag",
1129+
task_id="test_task",
1130+
run_id="test_run",
1131+
map_index=-1,
1132+
try_number=1,
1133+
dag_version_id=uuid.uuid4(),
1134+
),
1135+
),
1136+
is_failure_callback=False,
1137+
msg="Test success message",
1138+
)
1139+
1140+
_execute_dag_callbacks(dagbag, request, structlog.get_logger())
1141+
1142+
mock_supervisor_comms.send.assert_called_once_with(msg=operation_type)
1143+
assert operation_response in caplog.text
1144+
10491145

10501146
class TestExecuteTaskCallbacks:
10511147
"""Test the _execute_task_callbacks function"""

0 commit comments

Comments
 (0)