Skip to content

Commit 1129231

Browse files
CrowiantAnton Nitochkin
andauthored
Change serialization inside GenAIGeminiCreateBatchJobOperator (#61253)
Co-authored-by: Anton Nitochkin <nitochkin@google.com>
1 parent 8c96236 commit 1129231

File tree

2 files changed

+69
-2
lines changed
  • providers/google

2 files changed

+69
-2
lines changed

providers/google/src/airflow/providers/google/cloud/operators/gen_ai.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -578,7 +578,7 @@ def execute(self, context: Context):
578578
job_results = self._prepare_results_for_xcom(job)
579579
context["ti"].xcom_push(key="job_results", value=job_results)
580580

581-
return dict(job)
581+
return job.model_dump(mode="json")
582582

583583
def execute_complete(self, context: Context, event: dict[str, Any]) -> dict[str, Any]:
584584
if event["status"] == "error":
@@ -588,7 +588,7 @@ def execute_complete(self, context: Context, event: dict[str, Any]) -> dict[str,
588588
if self.retrieve_result and job.error is None:
589589
job_results = self._prepare_results_for_xcom(job)
590590
context["ti"].xcom_push(key="job_results", value=job_results)
591-
return dict(job)
591+
return job.model_dump(mode="json")
592592

593593

594594
class GenAIGeminiGetBatchJobOperator(GoogleCloudBaseOperator):

providers/google/tests/unit/google/cloud/operators/test_gen_ai.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,20 @@
107107
TEST_BATCH_JOB_NAME = "test-name"
108108
TEST_FILE_NAME = "test-file"
109109
TEST_FILE_PATH = "test/path/to/file"
110+
TEST_CREATE_BATCH_JOB_RESPONSE = {
111+
"src": None,
112+
"dest": "test-batch-job-destination",
113+
"name": "test-name",
114+
"error": None,
115+
"model": "test-model",
116+
"state": "JOB_STATE_SUCCEEDED",
117+
"end_time": "test-end-datetime",
118+
"start_time": None,
119+
"create_time": "test-create-datetime",
120+
"update_time": "test-update-datetime",
121+
"display_name": "test-display-name",
122+
"completion_stats": None,
123+
}
110124

111125

112126
def assert_warning(msg: str, warnings):
@@ -301,6 +315,59 @@ def test_execute(self, mock_hook):
301315
create_batch_job_config=None,
302316
)
303317

318+
@mock.patch(GEN_AI_PATH.format("GenAIGeminiAPIHook"))
319+
def test_execute_return_value(self, mock_hook):
320+
expected_return = TEST_CREATE_BATCH_JOB_RESPONSE
321+
322+
mock_job = mock.MagicMock()
323+
mock_job.model_dump.return_value = expected_return
324+
mock_hook.return_value.create_batch_job.return_value = mock_job
325+
326+
op = GenAIGeminiCreateBatchJobOperator(
327+
task_id=TASK_ID,
328+
project_id=GCP_PROJECT,
329+
location=GCP_LOCATION,
330+
model=TEST_GEMINI_MODEL,
331+
gcp_conn_id=GCP_CONN_ID,
332+
impersonation_chain=IMPERSONATION_CHAIN,
333+
input_source=TEST_BATCH_JOB_INLINED_REQUESTS,
334+
gemini_api_key=TEST_GEMINI_API_KEY,
335+
deferrable=False,
336+
wait_until_complete=False,
337+
)
338+
339+
result = op.execute(context={"ti": mock.MagicMock()})
340+
341+
assert result == expected_return
342+
mock_job.model_dump.assert_called_once_with(mode="json")
343+
344+
@mock.patch(GEN_AI_PATH.format("GenAIGeminiAPIHook"))
345+
def test_execute_complete_return_value(self, mock_hook):
346+
expected_return = TEST_CREATE_BATCH_JOB_RESPONSE
347+
348+
event = {"status": "success", "job_name": "test-name"}
349+
350+
mock_job = mock.MagicMock()
351+
mock_job.model_dump.return_value = expected_return
352+
mock_hook.return_value.get_batch_job.return_value = mock_job
353+
354+
op = GenAIGeminiCreateBatchJobOperator(
355+
task_id=TASK_ID,
356+
project_id=GCP_PROJECT,
357+
location=GCP_LOCATION,
358+
model=TEST_GEMINI_MODEL,
359+
gcp_conn_id=GCP_CONN_ID,
360+
impersonation_chain=IMPERSONATION_CHAIN,
361+
input_source=TEST_BATCH_JOB_INLINED_REQUESTS,
362+
gemini_api_key=TEST_GEMINI_API_KEY,
363+
)
364+
365+
result = op.execute_complete(context={"ti": mock.MagicMock()}, event=event)
366+
367+
assert result == expected_return
368+
mock_hook.return_value.get_batch_job.assert_called_once_with("test-name")
369+
mock_job.model_dump.assert_called_once_with(mode="json")
370+
304371

305372
class TestGenAIGeminiGetBatchJobOperator:
306373
@mock.patch(GEN_AI_PATH.format("GenAIGeminiAPIHook"))

0 commit comments

Comments
 (0)