Skip to content

Commit e3aa919

Browse files
github-actions[bot], sjyangkevin
authored and committed
[v3-1-test] Fix Outlet Event Extra Data is Empty in Task Instance Success Listener (#54568) (#57031)
Co-authored-by: Kevin Yang <85313829+sjyangkevin@users.noreply.github.com>
1 parent 51817f0 commit e3aa919

File tree

3 files changed

+161
-10
lines changed

3 files changed

+161
-10
lines changed

scripts/ci/prek/check_template_context_variable_in_sync.py

Lines changed: 13 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -83,17 +83,25 @@ def extract_keys_from_dict(node: ast.Dict) -> typing.Iterator[str]:
8383
yield key.value
8484

8585
# Extract keys from the main `context` dictionary assignment
86-
context_assignment = next(
86+
context_assignment: ast.AnnAssign = next(
8787
stmt
8888
for stmt in fn_get_template_context.body
8989
if isinstance(stmt, ast.AnnAssign)
90-
and isinstance(stmt.target, ast.Name)
91-
and stmt.target.id == "context"
90+
and isinstance(stmt.target, ast.Attribute)
91+
and isinstance(stmt.target.value, ast.Name)
92+
and stmt.target.value.id == "self"
93+
and stmt.target.attr == "_context"
9294
)
9395

94-
if not isinstance(context_assignment.value, ast.Dict):
96+
if not isinstance(context_assignment.value, ast.BoolOp):
97+
raise TypeError("Expected a BoolOp like 'self._context or {...}'.")
98+
99+
context_assignment_op = context_assignment.value
100+
_, context_assignment_value = context_assignment_op.values
101+
102+
if not isinstance(context_assignment_value, ast.Dict):
95103
raise ValueError("'context' is not assigned a dictionary literal")
96-
yield from extract_keys_from_dict(context_assignment.value)
104+
yield from extract_keys_from_dict(context_assignment_value)
97105

98106
# Handle keys added conditionally in `if from_server`
99107
for stmt in fn_get_template_context.body:

task-sdk/src/airflow/sdk/execution_time/task_runner.py

Lines changed: 9 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -131,6 +131,9 @@ class RuntimeTaskInstance(TaskInstance):
131131

132132
task: BaseOperator
133133
bundle_instance: BaseDagBundle
134+
_context: Context | None = None
135+
"""The Task Instance context."""
136+
134137
_ti_context_from_server: Annotated[TIRunContext | None, Field(repr=False)] = None
135138
"""The Task Instance context from the API server, if any."""
136139

@@ -173,7 +176,9 @@ def get_template_context(self) -> Context:
173176

174177
validated_params = process_params(self.task.dag, self.task, dag_run_conf, suppress_exception=False)
175178

176-
context: Context = {
179+
# Cache the context object, which ensures that all calls to get_template_context
180+
# are operating on the same context object.
181+
self._context: Context = self._context or {
177182
# From the Task Execution interface
178183
"dag": self.task.dag,
179184
"inlets": self.task.inlets,
@@ -213,7 +218,7 @@ def get_template_context(self) -> Context:
213218
lambda: coerce_datetime(get_previous_dagrun_success(self.id).end_date)
214219
),
215220
}
216-
context.update(context_from_server)
221+
self._context.update(context_from_server)
217222

218223
if logical_date := coerce_datetime(dag_run.logical_date):
219224
if TYPE_CHECKING:
@@ -224,7 +229,7 @@ def get_template_context(self) -> Context:
224229
ts_nodash = logical_date.strftime("%Y%m%dT%H%M%S")
225230
ts_nodash_with_tz = ts.replace("-", "").replace(":", "")
226231
# logical_date and data_interval either coexist or be None together
227-
context.update(
232+
self._context.update(
228233
{
229234
# keys that depend on logical_date
230235
"logical_date": logical_date,
@@ -251,7 +256,7 @@ def get_template_context(self) -> Context:
251256
# existence. Should this be a private attribute on RuntimeTI instead perhaps?
252257
setattr(self, "_upstream_map_indexes", from_server.upstream_map_indexes)
253258

254-
return context
259+
return self._context
255260

256261
def render_templates(
257262
self, context: Context | None = None, jinja_env: jinja2.Environment | None = None

task-sdk/tests/task_sdk/execution_time/test_task_runner.py

Lines changed: 139 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -65,7 +65,7 @@
6565
)
6666
from airflow.sdk.bases.xcom import BaseXCom
6767
from airflow.sdk.definitions._internal.types import NOTSET, SET_DURING_EXECUTION, ArgNotSet
68-
from airflow.sdk.definitions.asset import Asset, AssetAlias, Dataset, Model
68+
from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetUniqueKey, Dataset, Model
6969
from airflow.sdk.definitions.param import DagParam
7070
from airflow.sdk.exceptions import ErrorType
7171
from airflow.sdk.execution_time.comms import (
@@ -2482,6 +2482,32 @@ def on_task_instance_failed(self, previous_state, task_instance, error):
24822482
def before_stopping(self, component):
24832483
self.component = component
24842484

2485+
class CustomOutletEventsListener:
2486+
def __init__(self):
2487+
self.outlet_events = []
2488+
self.error = None
2489+
2490+
def _add_outlet_events(self, context):
2491+
outlets = context["outlets"]
2492+
for outlet in outlets:
2493+
self.outlet_events.append(context["outlet_events"][outlet])
2494+
2495+
@hookimpl
2496+
def on_task_instance_running(self, previous_state, task_instance):
2497+
context = task_instance.get_template_context()
2498+
self._add_outlet_events(context)
2499+
2500+
@hookimpl
2501+
def on_task_instance_success(self, previous_state, task_instance):
2502+
context = task_instance.get_template_context()
2503+
self._add_outlet_events(context)
2504+
2505+
@hookimpl
2506+
def on_task_instance_failed(self, previous_state, task_instance, error):
2507+
context = task_instance.get_template_context()
2508+
self._add_outlet_events(context)
2509+
self.error = error
2510+
24852511
@pytest.fixture(autouse=True)
24862512
def clean_listener_manager(self):
24872513
lm = get_listener_manager()
@@ -2601,6 +2627,118 @@ def execute(self, context):
26012627
assert listener.state == [TaskInstanceState.RUNNING, TaskInstanceState.FAILED]
26022628
assert listener.error == error
26032629

2630+
def test_listener_access_outlet_event_on_running_and_success(self, mocked_parse, mock_supervisor_comms):
2631+
"""Test listener can access outlet events through invoking get_template_context() while task running and success"""
2632+
listener = self.CustomOutletEventsListener()
2633+
get_listener_manager().add_listener(listener)
2634+
2635+
test_asset = Asset("test-asset")
2636+
test_key = AssetUniqueKey(name="test-asset", uri="test-asset")
2637+
test_extra = {"name1": "value1", "nested_obj": {"name2": "value2"}}
2638+
2639+
class Producer(BaseOperator):
2640+
def execute(self, context):
2641+
outlet_events = context["outlet_events"]
2642+
outlet_events[test_asset].extra = test_extra
2643+
2644+
task = Producer(
2645+
task_id="test_listener_access_outlet_event_on_running_and_success", outlets=[test_asset]
2646+
)
2647+
dag = get_inline_dag(dag_id="test_dag", task=task)
2648+
ti = TaskInstance(
2649+
id=uuid7(),
2650+
task_id=task.task_id,
2651+
dag_id=dag.dag_id,
2652+
run_id="test_run",
2653+
try_number=1,
2654+
dag_version_id=uuid7(),
2655+
)
2656+
2657+
runtime_ti = RuntimeTaskInstance.model_construct(
2658+
**ti.model_dump(exclude_unset=True), task=task, start_date=timezone.utcnow()
2659+
)
2660+
2661+
log = mock.MagicMock()
2662+
context = runtime_ti.get_template_context()
2663+
2664+
with mock.patch(
2665+
"airflow.sdk.execution_time.task_runner._validate_task_inlets_and_outlets"
2666+
) as validate_mock:
2667+
state, _, _ = run(runtime_ti, context, log)
2668+
2669+
validate_mock.assert_called_once()
2670+
2671+
outlet_event_accessor = listener.outlet_events.pop()
2672+
assert outlet_event_accessor.key == test_key
2673+
assert outlet_event_accessor.extra == test_extra
2674+
2675+
finalize(runtime_ti, state, context, log)
2676+
2677+
outlet_event_accessor = listener.outlet_events.pop()
2678+
assert outlet_event_accessor.key == test_key
2679+
assert outlet_event_accessor.extra == test_extra
2680+
2681+
@pytest.mark.parametrize(
2682+
"exception",
2683+
[
2684+
ValueError("oops"),
2685+
SystemExit("oops"),
2686+
AirflowException("oops"),
2687+
],
2688+
ids=["ValueError", "SystemExit", "AirflowException"],
2689+
)
2690+
def test_listener_access_outlet_event_on_failed(self, mocked_parse, mock_supervisor_comms, exception):
2691+
"""Test listener can access outlet events through invoking get_template_context() while task failed"""
2692+
listener = self.CustomOutletEventsListener()
2693+
get_listener_manager().add_listener(listener)
2694+
2695+
test_asset = Asset("test-asset")
2696+
test_key = AssetUniqueKey(name="test-asset", uri="test-asset")
2697+
test_extra = {"name1": "value1", "nested_obj": {"name2": "value2"}}
2698+
2699+
class Producer(BaseOperator):
2700+
def execute(self, context):
2701+
outlet_events = context["outlet_events"]
2702+
outlet_events[test_asset].extra = test_extra
2703+
raise exception
2704+
2705+
task = Producer(task_id="test_listener_access_outlet_event_on_failed", outlets=[test_asset])
2706+
dag = get_inline_dag(dag_id="test_dag", task=task)
2707+
ti = TaskInstance(
2708+
id=uuid7(),
2709+
task_id=task.task_id,
2710+
dag_id=dag.dag_id,
2711+
run_id="test_run",
2712+
try_number=1,
2713+
dag_version_id=uuid7(),
2714+
)
2715+
2716+
runtime_ti = RuntimeTaskInstance.model_construct(
2717+
**ti.model_dump(exclude_unset=True), task=task, start_date=timezone.utcnow()
2718+
)
2719+
2720+
log = mock.MagicMock()
2721+
context = runtime_ti.get_template_context()
2722+
2723+
with mock.patch(
2724+
"airflow.sdk.execution_time.task_runner._validate_task_inlets_and_outlets"
2725+
) as validate_mock:
2726+
state, _, error = run(runtime_ti, context, log)
2727+
2728+
validate_mock.assert_called_once()
2729+
2730+
outlet_event_accessor = listener.outlet_events.pop()
2731+
assert outlet_event_accessor.key == test_key
2732+
assert outlet_event_accessor.extra == test_extra
2733+
2734+
finalize(runtime_ti, state, context, log, error)
2735+
2736+
outlet_event_accessor = listener.outlet_events.pop()
2737+
assert outlet_event_accessor.key == test_key
2738+
assert outlet_event_accessor.extra == test_extra
2739+
2740+
assert listener.error == error
2741+
26042742

26052743
@pytest.mark.usefixtures("mock_supervisor_comms")
26062744
class TestTaskRunnerCallsCallbacks:

0 commit comments

Comments
 (0)