|
107 | 107 | TEST_BATCH_JOB_NAME = "test-name" |
108 | 108 | TEST_FILE_NAME = "test-file" |
109 | 109 | TEST_FILE_PATH = "test/path/to/file" |
| 110 | +TEST_CREATE_BATCH_JOB_RESPONSE = { |
| 111 | + "src": None, |
| 112 | + "dest": "test-batch-job-destination", |
| 113 | + "name": "test-name", |
| 114 | + "error": None, |
| 115 | + "model": "test-model", |
| 116 | + "state": "JOB_STATE_SUCCEEDED", |
| 117 | + "end_time": "test-end-datetime", |
| 118 | + "start_time": None, |
| 119 | + "create_time": "test-create-datetime", |
| 120 | + "update_time": "test-update-datetime", |
| 121 | + "display_name": "test-display-name", |
| 122 | + "completion_stats": None, |
| 123 | +} |
110 | 124 |
|
111 | 125 |
|
112 | 126 | def assert_warning(msg: str, warnings): |
@@ -301,6 +315,59 @@ def test_execute(self, mock_hook): |
301 | 315 | create_batch_job_config=None, |
302 | 316 | ) |
303 | 317 |
|
| 318 | + @mock.patch(GEN_AI_PATH.format("GenAIGeminiAPIHook")) |
| 319 | + def test_execute_return_value(self, mock_hook): |
| 320 | + expected_return = TEST_CREATE_BATCH_JOB_RESPONSE |
| 321 | + |
| 322 | + mock_job = mock.MagicMock() |
| 323 | + mock_job.model_dump.return_value = expected_return |
| 324 | + mock_hook.return_value.create_batch_job.return_value = mock_job |
| 325 | + |
| 326 | + op = GenAIGeminiCreateBatchJobOperator( |
| 327 | + task_id=TASK_ID, |
| 328 | + project_id=GCP_PROJECT, |
| 329 | + location=GCP_LOCATION, |
| 330 | + model=TEST_GEMINI_MODEL, |
| 331 | + gcp_conn_id=GCP_CONN_ID, |
| 332 | + impersonation_chain=IMPERSONATION_CHAIN, |
| 333 | + input_source=TEST_BATCH_JOB_INLINED_REQUESTS, |
| 334 | + gemini_api_key=TEST_GEMINI_API_KEY, |
| 335 | + deferrable=False, |
| 336 | + wait_until_complete=False, |
| 337 | + ) |
| 338 | + |
| 339 | + result = op.execute(context={"ti": mock.MagicMock()}) |
| 340 | + |
| 341 | + assert result == expected_return |
| 342 | + mock_job.model_dump.assert_called_once_with(mode="json") |
| 343 | + |
| 344 | + @mock.patch(GEN_AI_PATH.format("GenAIGeminiAPIHook")) |
| 345 | + def test_execute_complete_return_value(self, mock_hook): |
| 346 | + expected_return = TEST_CREATE_BATCH_JOB_RESPONSE |
| 347 | + |
| 348 | + event = {"status": "success", "job_name": "test-name"} |
| 349 | + |
| 350 | + mock_job = mock.MagicMock() |
| 351 | + mock_job.model_dump.return_value = expected_return |
| 352 | + mock_hook.return_value.get_batch_job.return_value = mock_job |
| 353 | + |
| 354 | + op = GenAIGeminiCreateBatchJobOperator( |
| 355 | + task_id=TASK_ID, |
| 356 | + project_id=GCP_PROJECT, |
| 357 | + location=GCP_LOCATION, |
| 358 | + model=TEST_GEMINI_MODEL, |
| 359 | + gcp_conn_id=GCP_CONN_ID, |
| 360 | + impersonation_chain=IMPERSONATION_CHAIN, |
| 361 | + input_source=TEST_BATCH_JOB_INLINED_REQUESTS, |
| 362 | + gemini_api_key=TEST_GEMINI_API_KEY, |
| 363 | + ) |
| 364 | + |
| 365 | + result = op.execute_complete(context={"ti": mock.MagicMock()}, event=event) |
| 366 | + |
| 367 | + assert result == expected_return |
| 368 | + mock_hook.return_value.get_batch_job.assert_called_once_with("test-name") |
| 369 | + mock_job.model_dump.assert_called_once_with(mode="json") |
| 370 | + |
304 | 371 |
|
305 | 372 | class TestGenAIGeminiGetBatchJobOperator: |
306 | 373 | @mock.patch(GEN_AI_PATH.format("GenAIGeminiAPIHook")) |
|
0 commit comments