|
65 | 65 | ) |
66 | 66 | from airflow.sdk.bases.xcom import BaseXCom |
67 | 67 | from airflow.sdk.definitions._internal.types import NOTSET, SET_DURING_EXECUTION, ArgNotSet |
68 | | -from airflow.sdk.definitions.asset import Asset, AssetAlias, Dataset, Model |
| 68 | +from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetUniqueKey, Dataset, Model |
69 | 69 | from airflow.sdk.definitions.param import DagParam |
70 | 70 | from airflow.sdk.exceptions import ErrorType |
71 | 71 | from airflow.sdk.execution_time.comms import ( |
@@ -2482,6 +2482,32 @@ def on_task_instance_failed(self, previous_state, task_instance, error): |
2482 | 2482 | def before_stopping(self, component): |
2483 | 2483 | self.component = component |
2484 | 2484 |
|
| 2485 | + class CustomOutletEventsListener: |
| 2486 | + def __init__(self): |
| 2487 | + self.outlet_events = [] |
| 2488 | + self.error = None |
| 2489 | + |
| 2490 | + def _add_outlet_events(self, context): |
| 2491 | + outlets = context["outlets"] |
| 2492 | + for outlet in outlets: |
| 2493 | + self.outlet_events.append(context["outlet_events"][outlet]) |
| 2494 | + |
| 2495 | + @hookimpl |
| 2496 | + def on_task_instance_running(self, previous_state, task_instance): |
| 2497 | + context = task_instance.get_template_context() |
| 2498 | + self._add_outlet_events(context) |
| 2499 | + |
| 2500 | + @hookimpl |
| 2501 | + def on_task_instance_success(self, previous_state, task_instance): |
| 2502 | + context = task_instance.get_template_context() |
| 2503 | + self._add_outlet_events(context) |
| 2504 | + |
| 2505 | + @hookimpl |
| 2506 | + def on_task_instance_failed(self, previous_state, task_instance, error): |
| 2507 | + context = task_instance.get_template_context() |
| 2508 | + self._add_outlet_events(context) |
| 2509 | + self.error = error |
| 2510 | + |
2485 | 2511 | @pytest.fixture(autouse=True) |
2486 | 2512 | def clean_listener_manager(self): |
2487 | 2513 | lm = get_listener_manager() |
@@ -2601,6 +2627,118 @@ def execute(self, context): |
2601 | 2627 | assert listener.state == [TaskInstanceState.RUNNING, TaskInstanceState.FAILED] |
2602 | 2628 | assert listener.error == error |
2603 | 2629 |
|
| 2630 | + def test_listener_access_outlet_event_on_running_and_success(self, mocked_parse, mock_supervisor_comms): |
| 2631 | + """Test listener can access outlet events through invoking get_template_context() while task running and success""" |
| 2632 | + listener = self.CustomOutletEventsListener() |
| 2633 | + get_listener_manager().add_listener(listener) |
| 2634 | + |
| 2635 | + test_asset = Asset("test-asset") |
| 2636 | + test_key = AssetUniqueKey(name="test-asset", uri="test-asset") |
| 2637 | + test_extra = {"name1": "value1", "nested_obj": {"name2": "value2"}} |
| 2638 | + |
| 2639 | + class Producer(BaseOperator): |
| 2640 | + def execute(self, context): |
| 2641 | + outlet_events = context["outlet_events"] |
| 2642 | + outlet_events[test_asset].extra = test_extra |
| 2643 | + |
| 2644 | + task = Producer( |
| 2645 | + task_id="test_listener_access_outlet_event_on_running_and_success", outlets=[test_asset] |
| 2646 | + ) |
| 2647 | + dag = get_inline_dag(dag_id="test_dag", task=task) |
| 2648 | + ti = TaskInstance( |
| 2649 | + id=uuid7(), |
| 2650 | + task_id=task.task_id, |
| 2651 | + dag_id=dag.dag_id, |
| 2652 | + run_id="test_run", |
| 2653 | + try_number=1, |
| 2654 | + dag_version_id=uuid7(), |
| 2655 | + ) |
| 2656 | + |
| 2657 | + runtime_ti = RuntimeTaskInstance.model_construct( |
| 2658 | + **ti.model_dump(exclude_unset=True), task=task, start_date=timezone.utcnow() |
| 2659 | + ) |
| 2660 | + |
| 2661 | + log = mock.MagicMock() |
| 2662 | + context = runtime_ti.get_template_context() |
| 2663 | + |
| 2664 | + with mock.patch( |
| 2665 | + "airflow.sdk.execution_time.task_runner._validate_task_inlets_and_outlets" |
| 2666 | + ) as validate_mock: |
| 2667 | + state, _, _ = run(runtime_ti, context, log) |
| 2668 | + |
| 2669 | + validate_mock.assert_called_once() |
| 2670 | + |
| 2671 | + outlet_event_accessor = listener.outlet_events.pop() |
| 2672 | + assert outlet_event_accessor.key == test_key |
| 2673 | + assert outlet_event_accessor.extra == test_extra |
| 2674 | + |
| 2675 | + finalize(runtime_ti, state, context, log) |
| 2676 | + |
| 2677 | + outlet_event_accessor = listener.outlet_events.pop() |
| 2678 | + assert outlet_event_accessor.key == test_key |
| 2679 | + assert outlet_event_accessor.extra == test_extra |
| 2680 | + |
| 2681 | + @pytest.mark.parametrize( |
| 2682 | + "exception", |
| 2683 | + [ |
| 2684 | + ValueError("oops"), |
| 2685 | + SystemExit("oops"), |
| 2686 | + AirflowException("oops"), |
| 2687 | + ], |
| 2688 | + ids=["ValueError", "SystemExit", "AirflowException"], |
| 2689 | + ) |
| 2690 | + def test_listener_access_outlet_event_on_failed(self, mocked_parse, mock_supervisor_comms, exception): |
| 2691 | + """Test listener can access outlet events through invoking get_template_context() while task failed""" |
| 2692 | + listener = self.CustomOutletEventsListener() |
| 2693 | + get_listener_manager().add_listener(listener) |
| 2694 | + |
| 2695 | + test_asset = Asset("test-asset") |
| 2696 | + test_key = AssetUniqueKey(name="test-asset", uri="test-asset") |
| 2697 | + test_extra = {"name1": "value1", "nested_obj": {"name2": "value2"}} |
| 2698 | + |
| 2699 | + class Producer(BaseOperator): |
| 2700 | + def execute(self, context): |
| 2701 | + outlet_events = context["outlet_events"] |
| 2702 | + outlet_events[test_asset].extra = test_extra |
| 2703 | + raise exception |
| 2704 | + |
| 2705 | + task = Producer(task_id="test_listener_access_outlet_event_on_failed", outlets=[test_asset]) |
| 2706 | + dag = get_inline_dag(dag_id="test_dag", task=task) |
| 2707 | + ti = TaskInstance( |
| 2708 | + id=uuid7(), |
| 2709 | + task_id=task.task_id, |
| 2710 | + dag_id=dag.dag_id, |
| 2711 | + run_id="test_run", |
| 2712 | + try_number=1, |
| 2713 | + dag_version_id=uuid7(), |
| 2714 | + ) |
| 2715 | + |
| 2716 | + runtime_ti = RuntimeTaskInstance.model_construct( |
| 2717 | + **ti.model_dump(exclude_unset=True), task=task, start_date=timezone.utcnow() |
| 2718 | + ) |
| 2719 | + |
| 2720 | + log = mock.MagicMock() |
| 2721 | + context = runtime_ti.get_template_context() |
| 2722 | + |
| 2723 | + with mock.patch( |
| 2724 | + "airflow.sdk.execution_time.task_runner._validate_task_inlets_and_outlets" |
| 2725 | + ) as validate_mock: |
| 2726 | + state, _, error = run(runtime_ti, context, log) |
| 2727 | + |
| 2728 | + validate_mock.assert_called_once() |
| 2729 | + |
| 2730 | + outlet_event_accessor = listener.outlet_events.pop() |
| 2731 | + assert outlet_event_accessor.key == test_key |
| 2732 | + assert outlet_event_accessor.extra == test_extra |
| 2733 | + |
| 2734 | + finalize(runtime_ti, state, context, log, error) |
| 2735 | + |
| 2736 | + outlet_event_accessor = listener.outlet_events.pop() |
| 2737 | + assert outlet_event_accessor.key == test_key |
| 2738 | + assert outlet_event_accessor.extra == test_extra |
| 2739 | + |
| 2740 | + assert listener.error == error |
| 2741 | + |
2604 | 2742 |
|
2605 | 2743 | @pytest.mark.usefixtures("mock_supervisor_comms") |
2606 | 2744 | class TestTaskRunnerCallsCallbacks: |
|
0 commit comments