Skip to content

Commit 73282bd

Browse files
authored
Add wait_for_termination parameter and fix double-deferral in PowerBIDatasetRefreshOperator (#60369)
1 parent b8f3860 commit 73282bd

File tree

2 files changed

+183
-63
lines changed
  • providers/microsoft/azure

2 files changed

+183
-63
lines changed

providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/powerbi.py

Lines changed: 69 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from collections.abc import Sequence
2121
from typing import TYPE_CHECKING, Any
2222

23+
from airflow.configuration import conf
2324
from airflow.providers.common.compat.sdk import AirflowException, BaseOperator, BaseOperatorLink
2425
from airflow.providers.microsoft.azure.hooks.powerbi import PowerBIHook
2526
from airflow.providers.microsoft.azure.triggers.powerbi import (
@@ -61,10 +62,12 @@ class PowerBIDatasetRefreshOperator(BaseOperator):
6162
:param dataset_id: The dataset id.
6263
:param group_id: The workspace id.
6364
:param conn_id: Airflow Connection ID that contains the connection information for the Power BI account used for authentication.
64-
:param timeout: Time in seconds to wait for a dataset to reach a terminal status for asynchronous waits. Used only if ``wait_for_termination`` is True.
65+
:param timeout: Time in seconds to wait for a dataset to reach a terminal status for asynchronous waits. Used only if ``wait_for_completion`` is True.
6566
:param check_interval: Number of seconds to wait before rechecking the
6667
refresh status.
6768
:param request_body: Additional arguments to pass to the request body, as described in https://learn.microsoft.com/en-us/rest/api/power-bi/datasets/refresh-dataset-in-group#request-body.
69+
:param wait_for_completion: If True, wait for the dataset refresh to complete. If False, trigger the refresh and return immediately without waiting.
70+
:param deferrable: This parameter is deprecated and no longer has any effect. The operator now always uses deferrable execution when ``wait_for_completion=True``.
6871
"""
6972

7073
template_fields: Sequence[str] = (
@@ -86,13 +89,19 @@ def __init__(
8689
api_version: APIVersion | str | None = None,
8790
check_interval: int = 60,
8891
request_body: dict[str, Any] | None = None,
92+
wait_for_completion: bool = True,
93+
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
8994
**kwargs,
9095
) -> None:
9196
super().__init__(**kwargs)
97+
if "deferrable" in kwargs or deferrable is not True:
98+
self.log.warning(
99+
"The PowerBIDatasetRefreshOperator now always uses deferrable execution when wait_for_completion=True."
100+
)
92101
self.hook = PowerBIHook(conn_id=conn_id, proxies=proxies, api_version=api_version, timeout=timeout)
93102
self.dataset_id = dataset_id
94103
self.group_id = group_id
95-
self.wait_for_termination = True
104+
self.wait_for_completion = wait_for_completion
96105
self.conn_id = conn_id
97106
self.timeout = timeout
98107
self.check_interval = check_interval
@@ -108,63 +117,76 @@ def api_version(self) -> str | None:
108117

109118
def execute(self, context: Context):
110119
"""Refresh the Power BI Dataset."""
111-
if self.wait_for_termination:
112-
self.defer(
113-
trigger=PowerBITrigger(
114-
conn_id=self.conn_id,
115-
group_id=self.group_id,
116-
dataset_id=self.dataset_id,
117-
timeout=self.timeout,
118-
proxies=self.proxies,
119-
api_version=self.api_version,
120-
check_interval=self.check_interval,
121-
wait_for_termination=self.wait_for_termination,
122-
request_body=self.request_body,
123-
),
124-
method_name=self.get_refresh_status.__name__,
120+
if not self.wait_for_completion:
121+
# Fire and forget - synchronous execution, no deferral
122+
hook = PowerBIHook(
123+
conn_id=self.conn_id, proxies=self.proxies, api_version=self.api_version, timeout=self.timeout
125124
)
126125

127-
def get_refresh_status(self, context: Context, event: dict[str, str] | None = None):
128-
"""Push the refresh Id to XCom then runs the Trigger to wait for refresh completion."""
129-
if event:
130-
if event["status"] == "error":
131-
raise AirflowException(event["message"])
126+
dataset_refresh_id = hook.trigger_dataset_refresh(
127+
dataset_id=self.dataset_id,
128+
group_id=self.group_id,
129+
request_body=self.request_body,
130+
)
132131

133-
dataset_refresh_id = event["dataset_refresh_id"]
132+
if dataset_refresh_id:
133+
self.log.info("Triggered dataset refresh %s (fire-and-forget)", dataset_refresh_id)
134+
context["ti"].xcom_push(
135+
key=f"{self.task_id}.powerbi_dataset_refresh_id",
136+
value=dataset_refresh_id,
137+
)
138+
else:
139+
raise AirflowException("Failed to trigger dataset refresh")
140+
return
141+
142+
# Wait for termination - use deferrable trigger
143+
self.defer(
144+
trigger=PowerBITrigger(
145+
conn_id=self.conn_id,
146+
group_id=self.group_id,
147+
dataset_id=self.dataset_id,
148+
timeout=self.timeout,
149+
proxies=self.proxies,
150+
api_version=self.api_version,
151+
check_interval=self.check_interval,
152+
wait_for_termination=self.wait_for_completion,
153+
request_body=self.request_body,
154+
),
155+
method_name=self.execute_complete.__name__,
156+
)
157+
158+
def execute_complete(self, context: Context, event: dict[str, str]) -> None:
159+
"""
160+
Handle trigger completion and push results to XCom or raise an exception.
161+
162+
:param context: Airflow context dictionary
163+
:param event: Event dict from trigger with status and dataset_refresh_id
164+
"""
165+
if not event:
166+
return
167+
168+
# Success - push both ID and status to XCom
169+
dataset_refresh_id = event.get("dataset_refresh_id")
170+
dataset_refresh_status = event.get("dataset_refresh_status")
134171

135172
if dataset_refresh_id:
136173
context["ti"].xcom_push(
137-
key=f"{self.task_id}.powerbi_dataset_refresh_Id",
174+
key=f"{self.task_id}.powerbi_dataset_refresh_id",
138175
value=dataset_refresh_id,
139176
)
140-
self.defer(
141-
trigger=PowerBITrigger(
142-
conn_id=self.conn_id,
143-
group_id=self.group_id,
144-
dataset_id=self.dataset_id,
145-
dataset_refresh_id=dataset_refresh_id,
146-
timeout=self.timeout,
147-
proxies=self.proxies,
148-
api_version=self.api_version,
149-
check_interval=self.check_interval,
150-
wait_for_termination=self.wait_for_termination,
151-
),
152-
method_name=self.execute_complete.__name__,
153-
)
154-
155-
def execute_complete(self, context: Context, event: dict[str, str]) -> Any:
156-
"""
157-
Return immediately - callback for when the trigger fires.
158177

159-
Relies on trigger to throw an exception, otherwise it assumes execution was successful.
160-
"""
161-
if event:
178+
if dataset_refresh_status:
162179
context["ti"].xcom_push(
163180
key=f"{self.task_id}.powerbi_dataset_refresh_status",
164-
value=event["dataset_refresh_status"],
181+
value=dataset_refresh_status,
165182
)
166-
if event["status"] == "error":
167-
raise AirflowException(event["message"])
183+
184+
if event["status"] == "error":
185+
raise AirflowException(event["message"])
186+
187+
self.log.info(
188+
"Dataset refresh %s completed with status: %s", dataset_refresh_id, dataset_refresh_status
189+
)
168190

169191

170192
class PowerBIWorkspaceListOperator(BaseOperator):

providers/microsoft/azure/tests/unit/microsoft/azure/operators/test_powerbi.py

Lines changed: 114 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@
9898

9999
class TestPowerBIDatasetRefreshOperator:
100100
@mock.patch.object(BaseHook, "get_connection", side_effect=get_airflow_connection)
101-
def test_execute_wait_for_termination_with_deferrable(self, connection):
101+
def test_execute_wait_for_completion_with_deferrable(self, connection):
102102
operator = PowerBIDatasetRefreshOperator(
103103
**CONFIG,
104104
)
@@ -112,34 +112,52 @@ def test_execute_wait_for_termination_with_deferrable(self, connection):
112112

113113
@mock.patch.object(BaseHook, "get_connection", side_effect=get_airflow_connection)
114114
def test_powerbi_operator_async_get_refresh_status_success(self, connection):
115-
"""Assert that get_refresh_status log success message"""
115+
"""Test that execute defers once when wait_for_completion=True"""
116116
operator = PowerBIDatasetRefreshOperator(
117117
**CONFIG,
118+
wait_for_completion=True, # Explicitly set to True
118119
)
119-
context = {"ti": MagicMock()}
120-
context["ti"].task_id = TASK_ID
120+
context = mock_context(task=operator)
121121

122122
with pytest.raises(TaskDeferred) as exc:
123-
operator.get_refresh_status(
124-
context=context,
125-
event=SUCCESS_TRIGGER_EVENT,
126-
)
123+
operator.execute(context)
127124

125+
# Verify trigger is correct type
128126
assert isinstance(exc.value.trigger, PowerBITrigger)
129-
assert exc.value.trigger.dataset_refresh_id is NEW_REFRESH_REQUEST_ID
130-
assert context["ti"].xcom_push.call_count == 1
127+
128+
# Verify trigger has correct parameters
129+
assert exc.value.trigger.dataset_id == DATASET_ID
130+
assert exc.value.trigger.group_id == GROUP_ID
131+
assert exc.value.trigger.wait_for_termination is True
132+
133+
# Verify callback method name
134+
assert exc.value.method_name == "execute_complete"
135+
136+
# Verify dataset_refresh_id is None (trigger will create it)
137+
assert exc.value.trigger.dataset_refresh_id is None
131138

132139
def test_powerbi_operator_async_execute_complete_success(self):
133-
"""Assert that execute_complete log success message"""
134-
operator = PowerBIDatasetRefreshOperator(
135-
**CONFIG,
136-
)
140+
"""Assert that execute_complete processes success event correctly"""
141+
operator = PowerBIDatasetRefreshOperator(**CONFIG)
137142
context = {"ti": MagicMock()}
143+
138144
operator.execute_complete(
139145
context=context,
140146
event=SUCCESS_REFRESH_EVENT,
141147
)
142-
assert context["ti"].xcom_push.call_count == 1
148+
149+
# Should push both refresh_id and status
150+
assert context["ti"].xcom_push.call_count == 2
151+
152+
# Verify the XCom keys and values
153+
calls = context["ti"].xcom_push.call_args_list
154+
xcom_data = {call[1]["key"]: call[1]["value"] for call in calls}
155+
156+
assert f"{TASK_ID}.powerbi_dataset_refresh_id" in xcom_data
157+
assert xcom_data[f"{TASK_ID}.powerbi_dataset_refresh_id"] == NEW_REFRESH_REQUEST_ID
158+
159+
assert f"{TASK_ID}.powerbi_dataset_refresh_status" in xcom_data
160+
assert xcom_data[f"{TASK_ID}.powerbi_dataset_refresh_status"] == PowerBIDatasetRefreshStatus.COMPLETED
143161

144162
def test_powerbi_operator_async_execute_complete_fail(self):
145163
"""Assert that execute_complete raise exception on error"""
@@ -176,7 +194,7 @@ def test_powerbi_operator_refresh_fail(self):
176194
"dataset_refresh_id": "1234",
177195
},
178196
)
179-
assert context["ti"].xcom_push.call_count == 1
197+
assert context["ti"].xcom_push.call_count == 2
180198
assert str(exc.value) == "error message"
181199

182200
def test_execute_complete_no_event(self):
@@ -213,3 +231,83 @@ def test_powerbi_link(self, dag_maker, create_task_instance_of_operator):
213231
)
214232

215233
assert url == EXPECTED_ITEM_RUN_OP_EXTRA_LINK
234+
235+
@mock.patch("airflow.providers.microsoft.azure.operators.powerbi.PowerBIHook")
236+
@mock.patch.object(BaseHook, "get_connection", side_effect=get_airflow_connection)
237+
def test_execute_fire_and_forget_mode(self, mock_connection, mock_hook_class):
238+
"""Test fire-and-forget mode (wait_for_completion=False)"""
239+
mock_hook_instance = mock_hook_class.return_value
240+
mock_hook_instance.trigger_dataset_refresh.return_value = NEW_REFRESH_REQUEST_ID
241+
242+
operator = PowerBIDatasetRefreshOperator(
243+
**CONFIG,
244+
wait_for_completion=False,
245+
)
246+
context = {"ti": MagicMock()}
247+
context["ti"].task_id = TASK_ID
248+
249+
# Should not raise TaskDeferred
250+
result = operator.execute(context)
251+
252+
# Verify hook was called correctly
253+
mock_hook_instance.trigger_dataset_refresh.assert_called_once_with(
254+
dataset_id=DATASET_ID,
255+
group_id=GROUP_ID,
256+
request_body=REQUEST_BODY,
257+
)
258+
259+
# Verify XCom push
260+
assert context["ti"].xcom_push.call_count == 1
261+
call_args = context["ti"].xcom_push.call_args
262+
assert call_args[1]["key"] == f"{TASK_ID}.powerbi_dataset_refresh_id"
263+
assert call_args[1]["value"] == NEW_REFRESH_REQUEST_ID
264+
265+
# Should return None (completes synchronously)
266+
assert result is None
267+
268+
@mock.patch("airflow.providers.microsoft.azure.operators.powerbi.PowerBIHook")
269+
@mock.patch.object(BaseHook, "get_connection", side_effect=get_airflow_connection)
270+
def test_execute_fire_and_forget_mode_failure(self, mock_connection, mock_hook_class):
271+
"""Test fire-and-forget mode raises exception when trigger fails"""
272+
mock_hook_instance = mock_hook_class.return_value
273+
mock_hook_instance.trigger_dataset_refresh.return_value = None
274+
275+
operator = PowerBIDatasetRefreshOperator(
276+
**CONFIG,
277+
wait_for_completion=False,
278+
)
279+
context = {"ti": MagicMock()}
280+
context["ti"].task_id = TASK_ID
281+
282+
# Should raise AirflowException
283+
with pytest.raises(AirflowException, match="Failed to trigger dataset refresh"):
284+
operator.execute(context)
285+
286+
# Should not push to XCom on failure
287+
assert context["ti"].xcom_push.call_count == 0
288+
289+
@mock.patch.object(BaseHook, "get_connection", side_effect=get_airflow_connection)
290+
def test_execute_default_behavior_waits_for_completion(self, mock_connection):
291+
"""Test that default behavior (wait_for_completion=True) defers and waits"""
292+
config_without_wait = {
293+
"task_id": TASK_ID,
294+
"conn_id": DEFAULT_CONNECTION_CLIENT_SECRET,
295+
"group_id": GROUP_ID,
296+
"dataset_id": DATASET_ID,
297+
"request_body": REQUEST_BODY,
298+
"check_interval": 1,
299+
"timeout": 3,
300+
# Deliberately exclude wait_for_completion - should default to True
301+
}
302+
303+
operator = PowerBIDatasetRefreshOperator(**config_without_wait)
304+
context = mock_context(task=operator)
305+
306+
# Should defer (because default is wait_for_completion=True)
307+
with pytest.raises(TaskDeferred) as exc:
308+
operator.execute(context)
309+
310+
# Verify it deferred with correct trigger
311+
assert isinstance(exc.value.trigger, PowerBITrigger)
312+
assert exc.value.trigger.wait_for_termination is True
313+
assert exc.value.method_name == "execute_complete"

0 commit comments

Comments
 (0)