|
18 | 18 |
|
19 | 19 | import asyncio |
20 | 20 | import contextlib |
| 21 | +import logging |
21 | 22 | from asyncio import CancelledError, Future, sleep |
22 | 23 | from unittest import mock |
23 | 24 |
|
|
50 | 51 | TEST_GCP_CONN_ID = "google_cloud_default" |
51 | 52 | TEST_OPERATION_NAME = "name" |
52 | 53 | TEST_JOB_ID = "test-job-id" |
| 54 | +TEST_RUNNING_CLUSTER = Cluster( |
| 55 | + cluster_name=TEST_CLUSTER_NAME, |
| 56 | + status=ClusterStatus(state=ClusterStatus.State.RUNNING), |
| 57 | +) |
| 58 | +TEST_ERROR_CLUSTER = Cluster( |
| 59 | + cluster_name=TEST_CLUSTER_NAME, |
| 60 | + status=ClusterStatus(state=ClusterStatus.State.ERROR), |
| 61 | +) |
53 | 62 |
|
54 | 63 |
|
55 | 64 | @pytest.fixture |
@@ -158,28 +167,56 @@ def test_async_cluster_trigger_serialization_should_execute_successfully(self, c |
158 | 167 | @pytest.mark.db_test |
159 | 168 | @pytest.mark.asyncio |
160 | 169 | @mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocClusterTrigger.get_async_hook") |
161 | | - @mock.patch.object(DataprocClusterTrigger, "log") |
| 170 | + async def test_async_cluster_triggers_on_success_should_execute_successfully( |
| 171 | + self, mock_get_async_hook, cluster_trigger |
| 172 | + ): |
| 173 | + future = asyncio.Future() |
| 174 | + future.set_result(TEST_RUNNING_CLUSTER) |
| 175 | + mock_get_async_hook.return_value.get_cluster.return_value = future |
| 176 | + |
| 177 | + generator = cluster_trigger.run() |
| 178 | + actual_event = await generator.asend(None) |
| 179 | + |
| 180 | + expected_event = TriggerEvent( |
| 181 | + { |
| 182 | + "cluster_name": TEST_CLUSTER_NAME, |
| 183 | + "cluster_state": ClusterStatus.State(ClusterStatus.State.RUNNING).name, |
| 184 | + "cluster": actual_event.payload["cluster"], |
| 185 | + } |
| 186 | + ) |
| 187 | + assert expected_event == actual_event |
| 188 | + |
| 189 | + @pytest.mark.db_test |
| 190 | + @pytest.mark.asyncio |
| 191 | + @mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocClusterTrigger.fetch_cluster") |
| 192 | + @mock.patch( |
| 193 | + "airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.delete_cluster", |
| 194 | + return_value=asyncio.Future(), |
| 195 | + ) |
| 196 | + @mock.patch("google.auth.default") |
162 | 197 | async def test_async_cluster_trigger_run_returns_error_event( |
163 | | - self, mock_log, mock_get_async_hook, cluster_trigger |
| 198 | + self, mock_auth, mock_delete_cluster, mock_fetch_cluster, cluster_trigger, async_get_cluster, caplog |
164 | 199 | ): |
165 | | - # Mock delete_cluster to return a Future |
166 | | - mock_delete_future = asyncio.Future() |
167 | | - mock_delete_future.set_result(None) |
168 | | - mock_get_async_hook.return_value.delete_cluster.return_value = mock_delete_future |
| 200 | + mock_credentials = mock.MagicMock() |
| 201 | + mock_credentials.universe_domain = "googleapis.com" |
169 | 202 |
|
170 | | - mock_cluster = mock.MagicMock() |
171 | | - mock_cluster.status = ClusterStatus(state=ClusterStatus.State.ERROR) |
| 203 | + mock_auth.return_value = (mock_credentials, "project-id") |
172 | 204 |
|
173 | | - future = asyncio.Future() |
174 | | - future.set_result(mock_cluster) |
175 | | - mock_get_async_hook.return_value.get_cluster.return_value = future |
| 205 | + mock_delete_cluster.return_value = asyncio.Future() |
| 206 | + mock_delete_cluster.return_value.set_result(None) |
| 207 | + |
| 208 | + mock_fetch_cluster.return_value = TEST_ERROR_CLUSTER |
| 209 | + |
| 210 | + caplog.set_level(logging.INFO) |
176 | 211 |
|
177 | 212 | trigger_event = None |
178 | 213 | async for event in cluster_trigger.run(): |
179 | 214 | trigger_event = event |
180 | 215 |
|
181 | 216 | assert trigger_event.payload["cluster_name"] == TEST_CLUSTER_NAME |
182 | | - assert trigger_event.payload["cluster_state"] == ClusterStatus.State.DELETING |
| 217 | + assert ( |
| 218 | + trigger_event.payload["cluster_state"] == ClusterStatus.State(ClusterStatus.State.DELETING).name |
| 219 | + ) |
183 | 220 |
|
184 | 221 | @pytest.mark.db_test |
185 | 222 | @pytest.mark.asyncio |
@@ -321,31 +358,6 @@ async def test_cluster_trigger_run_cancelled_not_safe_to_cancel( |
321 | 358 | assert mock_delete_cluster.call_count == 0 |
322 | 359 | mock_delete_cluster.assert_not_called() |
323 | 360 |
|
324 | | - @pytest.mark.db_test |
325 | | - @pytest.mark.asyncio |
326 | | - @mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocClusterTrigger.get_async_hook") |
327 | | - async def test_async_cluster_triggers_on_success_should_execute_successfully( |
328 | | - self, mock_get_async_hook, cluster_trigger |
329 | | - ): |
330 | | - mock_cluster = mock.MagicMock() |
331 | | - mock_cluster.status = ClusterStatus(state=ClusterStatus.State.RUNNING) |
332 | | - |
333 | | - future = asyncio.Future() |
334 | | - future.set_result(mock_cluster) |
335 | | - mock_get_async_hook.return_value.get_cluster.return_value = future |
336 | | - |
337 | | - generator = cluster_trigger.run() |
338 | | - actual_event = await generator.asend(None) |
339 | | - |
340 | | - expected_event = TriggerEvent( |
341 | | - { |
342 | | - "cluster_name": TEST_CLUSTER_NAME, |
343 | | - "cluster_state": ClusterStatus.State.RUNNING, |
344 | | - "cluster": actual_event.payload["cluster"], |
345 | | - } |
346 | | - ) |
347 | | - assert expected_event == actual_event |
348 | | - |
349 | 361 |
|
350 | 362 | class TestDataprocBatchTrigger: |
351 | 363 | def test_async_create_batch_trigger_serialization_should_execute_successfully(self, batch_trigger): |
|
0 commit comments