|
67 | 67 | ) |
68 | 68 | from airflow.sdk.bases.xcom import BaseXCom |
69 | 69 | from airflow.sdk.definitions._internal.types import NOTSET, SET_DURING_EXECUTION, ArgNotSet |
70 | | -from airflow.sdk.definitions.asset import Asset, AssetAlias, Dataset, Model |
| 70 | +from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetUniqueKey, Dataset, Model |
71 | 71 | from airflow.sdk.definitions.param import DagParam |
72 | 72 | from airflow.sdk.exceptions import ErrorType |
73 | 73 | from airflow.sdk.execution_time.comms import ( |
@@ -2503,6 +2503,32 @@ def on_task_instance_failed(self, previous_state, task_instance, error): |
2503 | 2503 | def before_stopping(self, component): |
2504 | 2504 | self.component = component |
2505 | 2505 |
|
| 2506 | + class CustomOutletEventsListener: |
| 2507 | + def __init__(self): |
| 2508 | + self.outlet_events = [] |
| 2509 | + self.error = None |
| 2510 | + |
| 2511 | + def _add_outlet_events(self, context): |
| 2512 | + outlets = context["outlets"] |
| 2513 | + for outlet in outlets: |
| 2514 | + self.outlet_events.append(context["outlet_events"][outlet]) |
| 2515 | + |
| 2516 | + @hookimpl |
| 2517 | + def on_task_instance_running(self, previous_state, task_instance): |
| 2518 | + context = task_instance.get_template_context() |
| 2519 | + self._add_outlet_events(context) |
| 2520 | + |
| 2521 | + @hookimpl |
| 2522 | + def on_task_instance_success(self, previous_state, task_instance): |
| 2523 | + context = task_instance.get_template_context() |
| 2524 | + self._add_outlet_events(context) |
| 2525 | + |
| 2526 | + @hookimpl |
| 2527 | + def on_task_instance_failed(self, previous_state, task_instance, error): |
| 2528 | + context = task_instance.get_template_context() |
| 2529 | + self._add_outlet_events(context) |
| 2530 | + self.error = error |
| 2531 | + |
2506 | 2532 | @pytest.fixture(autouse=True) |
2507 | 2533 | def clean_listener_manager(self): |
2508 | 2534 | lm = get_listener_manager() |
@@ -2622,6 +2648,118 @@ def execute(self, context): |
2622 | 2648 | assert listener.state == [TaskInstanceState.RUNNING, TaskInstanceState.FAILED] |
2623 | 2649 | assert listener.error == error |
2624 | 2650 |
|
| 2651 | + def test_listener_access_outlet_event_on_running_and_success(self, mocked_parse, mock_supervisor_comms): |
| 2652 | + """Test listener can access outlet events through invoking get_template_context() while task running and success""" |
| 2653 | + listener = self.CustomOutletEventsListener() |
| 2654 | + get_listener_manager().add_listener(listener) |
| 2655 | + |
| 2656 | + test_asset = Asset("test-asset") |
| 2657 | + test_key = AssetUniqueKey(name="test-asset", uri="test-asset") |
| 2658 | + test_extra = {"name1": "value1", "nested_obj": {"name2": "value2"}} |
| 2659 | + |
| 2660 | + class Producer(BaseOperator): |
| 2661 | + def execute(self, context): |
| 2662 | + outlet_events = context["outlet_events"] |
| 2663 | + outlet_events[test_asset].extra = test_extra |
| 2664 | + |
| 2665 | + task = Producer( |
| 2666 | + task_id="test_listener_access_outlet_event_on_running_and_success", outlets=[test_asset] |
| 2667 | + ) |
| 2668 | + dag = get_inline_dag(dag_id="test_dag", task=task) |
| 2669 | + ti = TaskInstance( |
| 2670 | + id=uuid7(), |
| 2671 | + task_id=task.task_id, |
| 2672 | + dag_id=dag.dag_id, |
| 2673 | + run_id="test_run", |
| 2674 | + try_number=1, |
| 2675 | + dag_version_id=uuid7(), |
| 2676 | + ) |
| 2677 | + |
| 2678 | + runtime_ti = RuntimeTaskInstance.model_construct( |
| 2679 | + **ti.model_dump(exclude_unset=True), task=task, start_date=timezone.utcnow() |
| 2680 | + ) |
| 2681 | + |
| 2682 | + log = mock.MagicMock() |
| 2683 | + context = runtime_ti.get_template_context() |
| 2684 | + |
| 2685 | + with mock.patch( |
| 2686 | + "airflow.sdk.execution_time.task_runner._validate_task_inlets_and_outlets" |
| 2687 | + ) as validate_mock: |
| 2688 | + state, _, _ = run(runtime_ti, context, log) |
| 2689 | + |
| 2690 | + validate_mock.assert_called_once() |
| 2691 | + |
| 2692 | + outlet_event_accessor = listener.outlet_events.pop() |
| 2693 | + assert outlet_event_accessor.key == test_key |
| 2694 | + assert outlet_event_accessor.extra == test_extra |
| 2695 | + |
| 2696 | + finalize(runtime_ti, state, context, log) |
| 2697 | + |
| 2698 | + outlet_event_accessor = listener.outlet_events.pop() |
| 2699 | + assert outlet_event_accessor.key == test_key |
| 2700 | + assert outlet_event_accessor.extra == test_extra |
| 2701 | + |
| 2702 | + @pytest.mark.parametrize( |
| 2703 | + "exception", |
| 2704 | + [ |
| 2705 | + ValueError("oops"), |
| 2706 | + SystemExit("oops"), |
| 2707 | + AirflowException("oops"), |
| 2708 | + ], |
| 2709 | + ids=["ValueError", "SystemExit", "AirflowException"], |
| 2710 | + ) |
| 2711 | + def test_listener_access_outlet_event_on_failed(self, mocked_parse, mock_supervisor_comms, exception): |
| 2712 | + """Test listener can access outlet events through invoking get_template_context() while task failed""" |
| 2713 | + listener = self.CustomOutletEventsListener() |
| 2714 | + get_listener_manager().add_listener(listener) |
| 2715 | + |
| 2716 | + test_asset = Asset("test-asset") |
| 2717 | + test_key = AssetUniqueKey(name="test-asset", uri="test-asset") |
| 2718 | + test_extra = {"name1": "value1", "nested_obj": {"name2": "value2"}} |
| 2719 | + |
| 2720 | + class Producer(BaseOperator): |
| 2721 | + def execute(self, context): |
| 2722 | + outlet_events = context["outlet_events"] |
| 2723 | + outlet_events[test_asset].extra = test_extra |
| 2724 | + raise exception |
| 2725 | + |
| 2726 | + task = Producer(task_id="test_listener_access_outlet_event_on_failed", outlets=[test_asset]) |
| 2727 | + dag = get_inline_dag(dag_id="test_dag", task=task) |
| 2728 | + ti = TaskInstance( |
| 2729 | + id=uuid7(), |
| 2730 | + task_id=task.task_id, |
| 2731 | + dag_id=dag.dag_id, |
| 2732 | + run_id="test_run", |
| 2733 | + try_number=1, |
| 2734 | + dag_version_id=uuid7(), |
| 2735 | + ) |
| 2736 | + |
| 2737 | + runtime_ti = RuntimeTaskInstance.model_construct( |
| 2738 | + **ti.model_dump(exclude_unset=True), task=task, start_date=timezone.utcnow() |
| 2739 | + ) |
| 2740 | + |
| 2741 | + log = mock.MagicMock() |
| 2742 | + context = runtime_ti.get_template_context() |
| 2743 | + |
| 2744 | + with mock.patch( |
| 2745 | + "airflow.sdk.execution_time.task_runner._validate_task_inlets_and_outlets" |
| 2746 | + ) as validate_mock: |
| 2747 | + state, _, error = run(runtime_ti, context, log) |
| 2748 | + |
| 2749 | + validate_mock.assert_called_once() |
| 2750 | + |
| 2751 | + outlet_event_accessor = listener.outlet_events.pop() |
| 2752 | + assert outlet_event_accessor.key == test_key |
| 2753 | + assert outlet_event_accessor.extra == test_extra |
| 2754 | + |
| 2755 | + finalize(runtime_ti, state, context, log, error) |
| 2756 | + |
| 2757 | + outlet_event_accessor = listener.outlet_events.pop() |
| 2758 | + assert outlet_event_accessor.key == test_key |
| 2759 | + assert outlet_event_accessor.extra == test_extra |
| 2760 | + |
| 2761 | + assert listener.error == error |
| 2762 | + |
2625 | 2763 |
|
2626 | 2764 | @pytest.mark.usefixtures("mock_supervisor_comms") |
2627 | 2765 | class TestTaskRunnerCallsCallbacks: |
|
0 commit comments