Skip to content

Commit a2581cc

Browse files
fix dataproc trigger (#53485)
# Conflicts:
#   providers/google/tests/unit/google/cloud/triggers/test_dataproc.py
1 parent 0f0ea3b commit a2581cc

File tree

3 files changed, with 58 additions and 45 deletions.

providers/google/src/airflow/providers/google/cloud/operators/dataproc.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -907,7 +907,7 @@ def execute_complete(self, context: Context, event: dict[str, Any]) -> Any:
907907
cluster_state = event["cluster_state"]
908908
cluster_name = event["cluster_name"]
909909

910-
if cluster_state == ClusterStatus.State.ERROR:
910+
if cluster_state == ClusterStatus.State(ClusterStatus.State.DELETING).name:
911911
raise AirflowException(f"Cluster is in ERROR state:\n{cluster_name}")
912912

913913
self.log.info("%s completed successfully.", self.task_id)

providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py

Lines changed: 8 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -316,23 +316,24 @@ async def run(self) -> AsyncIterator[TriggerEvent]:
316316
yield TriggerEvent(
317317
{
318318
"cluster_name": self.cluster_name,
319-
"cluster_state": ClusterStatus.State.DELETING,
320-
"cluster": cluster,
319+
"cluster_state": ClusterStatus.State(ClusterStatus.State.DELETING).name,
320+
"cluster": Cluster.to_dict(cluster),
321321
}
322322
)
323323
return
324324
elif state == ClusterStatus.State.RUNNING:
325325
yield TriggerEvent(
326326
{
327327
"cluster_name": self.cluster_name,
328-
"cluster_state": state,
329-
"cluster": cluster,
328+
"cluster_state": ClusterStatus.State(state).name,
329+
"cluster": Cluster.to_dict(cluster),
330330
}
331331
)
332332
return
333-
self.log.info("Current state is %s", state)
334-
self.log.info("Sleeping for %s seconds.", self.polling_interval_seconds)
335-
await asyncio.sleep(self.polling_interval_seconds)
333+
else:
334+
self.log.info("Current state is %s", state)
335+
self.log.info("Sleeping for %s seconds.", self.polling_interval_seconds)
336+
await asyncio.sleep(self.polling_interval_seconds)
336337
except asyncio.CancelledError:
337338
try:
338339
if self.delete_on_error and await self.safe_to_cancel():

providers/google/tests/unit/google/cloud/triggers/test_dataproc.py

Lines changed: 49 additions & 37 deletions
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,7 @@
1818

1919
import asyncio
2020
import contextlib
21+
import logging
2122
from asyncio import CancelledError, Future, sleep
2223
from unittest import mock
2324

@@ -50,6 +51,14 @@
5051
TEST_GCP_CONN_ID = "google_cloud_default"
5152
TEST_OPERATION_NAME = "name"
5253
TEST_JOB_ID = "test-job-id"
54+
TEST_RUNNING_CLUSTER = Cluster(
55+
cluster_name=TEST_CLUSTER_NAME,
56+
status=ClusterStatus(state=ClusterStatus.State.RUNNING),
57+
)
58+
TEST_ERROR_CLUSTER = Cluster(
59+
cluster_name=TEST_CLUSTER_NAME,
60+
status=ClusterStatus(state=ClusterStatus.State.ERROR),
61+
)
5362

5463

5564
@pytest.fixture
@@ -158,28 +167,56 @@ def test_async_cluster_trigger_serialization_should_execute_successfully(self, c
158167
@pytest.mark.db_test
159168
@pytest.mark.asyncio
160169
@mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocClusterTrigger.get_async_hook")
161-
@mock.patch.object(DataprocClusterTrigger, "log")
170+
async def test_async_cluster_triggers_on_success_should_execute_successfully(
171+
self, mock_get_async_hook, cluster_trigger
172+
):
173+
future = asyncio.Future()
174+
future.set_result(TEST_RUNNING_CLUSTER)
175+
mock_get_async_hook.return_value.get_cluster.return_value = future
176+
177+
generator = cluster_trigger.run()
178+
actual_event = await generator.asend(None)
179+
180+
expected_event = TriggerEvent(
181+
{
182+
"cluster_name": TEST_CLUSTER_NAME,
183+
"cluster_state": ClusterStatus.State(ClusterStatus.State.RUNNING).name,
184+
"cluster": actual_event.payload["cluster"],
185+
}
186+
)
187+
assert expected_event == actual_event
188+
189+
@pytest.mark.db_test
190+
@pytest.mark.asyncio
191+
@mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocClusterTrigger.fetch_cluster")
192+
@mock.patch(
193+
"airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.delete_cluster",
194+
return_value=asyncio.Future(),
195+
)
196+
@mock.patch("google.auth.default")
162197
async def test_async_cluster_trigger_run_returns_error_event(
163-
self, mock_log, mock_get_async_hook, cluster_trigger
198+
self, mock_auth, mock_delete_cluster, mock_fetch_cluster, cluster_trigger, async_get_cluster, caplog
164199
):
165-
# Mock delete_cluster to return a Future
166-
mock_delete_future = asyncio.Future()
167-
mock_delete_future.set_result(None)
168-
mock_get_async_hook.return_value.delete_cluster.return_value = mock_delete_future
200+
mock_credentials = mock.MagicMock()
201+
mock_credentials.universe_domain = "googleapis.com"
169202

170-
mock_cluster = mock.MagicMock()
171-
mock_cluster.status = ClusterStatus(state=ClusterStatus.State.ERROR)
203+
mock_auth.return_value = (mock_credentials, "project-id")
172204

173-
future = asyncio.Future()
174-
future.set_result(mock_cluster)
175-
mock_get_async_hook.return_value.get_cluster.return_value = future
205+
mock_delete_cluster.return_value = asyncio.Future()
206+
mock_delete_cluster.return_value.set_result(None)
207+
208+
mock_fetch_cluster.return_value = TEST_ERROR_CLUSTER
209+
210+
caplog.set_level(logging.INFO)
176211

177212
trigger_event = None
178213
async for event in cluster_trigger.run():
179214
trigger_event = event
180215

181216
assert trigger_event.payload["cluster_name"] == TEST_CLUSTER_NAME
182-
assert trigger_event.payload["cluster_state"] == ClusterStatus.State.DELETING
217+
assert (
218+
trigger_event.payload["cluster_state"] == ClusterStatus.State(ClusterStatus.State.DELETING).name
219+
)
183220

184221
@pytest.mark.db_test
185222
@pytest.mark.asyncio
@@ -321,31 +358,6 @@ async def test_cluster_trigger_run_cancelled_not_safe_to_cancel(
321358
assert mock_delete_cluster.call_count == 0
322359
mock_delete_cluster.assert_not_called()
323360

324-
@pytest.mark.db_test
325-
@pytest.mark.asyncio
326-
@mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocClusterTrigger.get_async_hook")
327-
async def test_async_cluster_triggers_on_success_should_execute_successfully(
328-
self, mock_get_async_hook, cluster_trigger
329-
):
330-
mock_cluster = mock.MagicMock()
331-
mock_cluster.status = ClusterStatus(state=ClusterStatus.State.RUNNING)
332-
333-
future = asyncio.Future()
334-
future.set_result(mock_cluster)
335-
mock_get_async_hook.return_value.get_cluster.return_value = future
336-
337-
generator = cluster_trigger.run()
338-
actual_event = await generator.asend(None)
339-
340-
expected_event = TriggerEvent(
341-
{
342-
"cluster_name": TEST_CLUSTER_NAME,
343-
"cluster_state": ClusterStatus.State.RUNNING,
344-
"cluster": actual_event.payload["cluster"],
345-
}
346-
)
347-
assert expected_event == actual_event
348-
349361

350362
class TestDataprocBatchTrigger:
351363
def test_async_create_batch_trigger_serialization_should_execute_successfully(self, batch_trigger):

0 commit comments

Comments (0)