Skip to content

Commit 63b4751

Browse files
authored
[API] Use BulkTaskInstanceBody for patching tis with new state (apache#57226)
1 parent cd32d1b commit 63b4751

File tree

2 files changed

+26
-1
lines changed

2 files changed

+26
-1
lines changed

airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -923,11 +923,23 @@ def patch_task_instance(
923923

924924
for key, _ in data.items():
925925
if key == "new_state":
926+
# Create BulkTaskInstanceBody object with map_index field
927+
bulk_ti_body = BulkTaskInstanceBody(
928+
task_id=task_id,
929+
map_index=map_index,
930+
new_state=body.new_state,
931+
note=body.note,
932+
include_upstream=body.include_upstream,
933+
include_downstream=body.include_downstream,
934+
include_future=body.include_future,
935+
include_past=body.include_past,
936+
)
937+
926938
_patch_task_instance_state(
927939
task_id=task_id,
928940
dag_run_id=dag_run_id,
929941
dag=dag,
930-
task_instance_body=body,
942+
task_instance_body=bulk_ti_body,
931943
data=data,
932944
session=session,
933945
)

airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3904,8 +3904,16 @@ def test_should_update_mapped_task_instance_state(self, test_client, session):
39043904
ti = TaskInstance(
39053905
task=tis[0].task, run_id=tis[0].run_id, map_index=map_index, dag_version_id=tis[0].dag_version_id
39063906
)
3907+
ti_2 = TaskInstance(
3908+
task=tis[0].task,
3909+
run_id=tis[0].run_id,
3910+
map_index=map_index + 1,
3911+
dag_version_id=tis[0].dag_version_id,
3912+
)
39073913
ti.rendered_task_instance_fields = RTIF(ti, render_templates=False)
3914+
ti_2.rendered_task_instance_fields = RTIF(ti_2, render_templates=False)
39083915
session.add(ti)
3916+
session.add(ti_2)
39093917
session.commit()
39103918

39113919
response = test_client.patch(
@@ -3920,6 +3928,11 @@ def test_should_update_mapped_task_instance_state(self, test_client, session):
39203928
assert response2.status_code == 200
39213929
assert response2.json()["state"] == self.NEW_STATE
39223930

3931+
response3 = test_client.get(f"{self.ENDPOINT_URL}/{map_index + 1}")
3932+
assert response3.status_code == 200
3933+
assert response3.json()["state"] != self.NEW_STATE
3934+
assert response3.json()["state"] is None
3935+
39233936
def test_should_update_mapped_task_instance_summary_state(self, test_client, session):
39243937
tis = self.create_task_instances(session)
39253938

0 commit comments

Comments
 (0)