|
11 | 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | | - |
15 | | -# |
16 | | -# Using Google Cloud Vertex AI to test the code samples. |
17 | | -# |
18 | | -from datetime import datetime as dt |
19 | | -import os |
20 | | - |
21 | 14 | from unittest.mock import MagicMock, patch |
22 | 15 |
|
23 | | -from google.cloud import bigquery, storage |
24 | 16 | from google.genai import types |
25 | 17 | from google.genai.types import JobState |
26 | | -import pytest |
27 | 18 |
|
28 | 19 | import batchpredict_embeddings_with_gcs |
29 | 20 | import batchpredict_with_bq |
30 | 21 | import batchpredict_with_gcs |
31 | 22 | import get_batch_job |
32 | 23 |
|
33 | 24 |
|
34 | | -os.environ["GOOGLE_GENAI_USE_VERTEXAI"] = "True" |
35 | | -os.environ["GOOGLE_CLOUD_LOCATION"] = "us-central1" |
36 | | -# The project name is included in the CICD pipeline |
37 | | -# os.environ['GOOGLE_CLOUD_PROJECT'] = "add-your-project-name" |
38 | | -BQ_OUTPUT_DATASET = f"{os.environ['GOOGLE_CLOUD_PROJECT']}.gen_ai_batch_prediction" |
39 | | -GCS_OUTPUT_BUCKET = "python-docs-samples-tests" |
40 | | - |
41 | | - |
42 | | -@pytest.fixture(scope="session") |
43 | | -def bq_output_uri() -> str: |
44 | | - table_name = f"text_output_{dt.now().strftime('%Y_%m_%d_T%H_%M_%S')}" |
45 | | - table_uri = f"{BQ_OUTPUT_DATASET}.{table_name}" |
| 25 | +@patch("google.genai.Client") |
| 26 | +@patch("time.sleep", return_value=None) |
| 27 | +def test_batch_prediction_embeddings_with_gcs( |
| 28 | + mock_sleep: MagicMock, mock_genai_client: MagicMock |
| 29 | +) -> None: |
| 30 | + # Mock the API response |
| 31 | + mock_batch_job_running = types.BatchJob( |
| 32 | + name="test-batch-job", state="JOB_STATE_RUNNING" |
| 33 | + ) |
| 34 | + mock_batch_job_succeeded = types.BatchJob( |
| 35 | + name="test-batch-job", state="JOB_STATE_SUCCEEDED" |
| 36 | + ) |
46 | 37 |
|
47 | | - yield f"bq://{table_uri}" |
| 38 | + mock_genai_client.return_value.batches.create.return_value = ( |
| 39 | + mock_batch_job_running |
| 40 | + ) |
| 41 | + mock_genai_client.return_value.batches.get.return_value = ( |
| 42 | + mock_batch_job_succeeded |
| 43 | + ) |
48 | 44 |
|
49 | | - bq_client = bigquery.Client() |
50 | | - bq_client.delete_table(table_uri, not_found_ok=True) |
| 45 | + response = batchpredict_embeddings_with_gcs.generate_content( |
| 46 | + output_uri="gs://test-bucket/test-prefix" |
| 47 | + ) |
51 | 48 |
|
| 49 | + mock_genai_client.assert_called_once_with( |
| 50 | + http_options=types.HttpOptions(api_version="v1") |
| 51 | + ) |
| 52 | + mock_genai_client.return_value.batches.create.assert_called_once() |
| 53 | + mock_genai_client.return_value.batches.get.assert_called_once() |
| 54 | + assert response == JobState.JOB_STATE_SUCCEEDED |
52 | 55 |
|
53 | | -@pytest.fixture(scope="session") |
54 | | -def gcs_output_uri() -> str: |
55 | | - prefix = f"text_output/{dt.now()}" |
56 | 56 |
|
57 | | - yield f"gs://{GCS_OUTPUT_BUCKET}/{prefix}" |
| 57 | +@patch("google.genai.Client") |
| 58 | +@patch("time.sleep", return_value=None) |
| 59 | +def test_batch_prediction_with_bq( |
| 60 | + mock_sleep: MagicMock, mock_genai_client: MagicMock |
| 61 | +) -> None: |
| 62 | + # Mock the API response |
| 63 | + mock_batch_job_running = types.BatchJob( |
| 64 | + name="test-batch-job", state="JOB_STATE_RUNNING" |
| 65 | + ) |
| 66 | + mock_batch_job_succeeded = types.BatchJob( |
| 67 | + name="test-batch-job", state="JOB_STATE_SUCCEEDED" |
| 68 | + ) |
58 | 69 |
|
59 | | - storage_client = storage.Client() |
60 | | - bucket = storage_client.get_bucket(GCS_OUTPUT_BUCKET) |
61 | | - blobs = bucket.list_blobs(prefix=prefix) |
62 | | - for blob in blobs: |
63 | | - blob.delete() |
| 70 | + mock_genai_client.return_value.batches.create.return_value = ( |
| 71 | + mock_batch_job_running |
| 72 | + ) |
| 73 | + mock_genai_client.return_value.batches.get.return_value = ( |
| 74 | + mock_batch_job_succeeded |
| 75 | + ) |
64 | 76 |
|
| 77 | + response = batchpredict_with_bq.generate_content( |
| 78 | + output_uri="bq://test-project.test_dataset.test_table" |
| 79 | + ) |
65 | 80 |
|
66 | | -def test_batch_prediction_embeddings_with_gcs(gcs_output_uri: str) -> None: |
67 | | - response = batchpredict_embeddings_with_gcs.generate_content( |
68 | | - output_uri=gcs_output_uri |
| 81 | + mock_genai_client.assert_called_once_with( |
| 82 | + http_options=types.HttpOptions(api_version="v1") |
69 | 83 | ) |
| 84 | + mock_genai_client.return_value.batches.create.assert_called_once() |
| 85 | + mock_genai_client.return_value.batches.get.assert_called_once() |
70 | 86 | assert response == JobState.JOB_STATE_SUCCEEDED |
71 | 87 |
|
72 | 88 |
|
73 | | -def test_batch_prediction_with_bq(bq_output_uri: str) -> None: |
74 | | - response = batchpredict_with_bq.generate_content(output_uri=bq_output_uri) |
75 | | - assert response == JobState.JOB_STATE_SUCCEEDED |
| 89 | +@patch("google.genai.Client") |
| 90 | +@patch("time.sleep", return_value=None) |
| 91 | +def test_batch_prediction_with_gcs( |
| 92 | + mock_sleep: MagicMock, mock_genai_client: MagicMock |
| 93 | +) -> None: |
| 94 | + # Mock the API response |
| 95 | + mock_batch_job_running = types.BatchJob( |
| 96 | + name="test-batch-job", state="JOB_STATE_RUNNING" |
| 97 | + ) |
| 98 | + mock_batch_job_succeeded = types.BatchJob( |
| 99 | + name="test-batch-job", state="JOB_STATE_SUCCEEDED" |
| 100 | + ) |
76 | 101 |
|
| 102 | + mock_genai_client.return_value.batches.create.return_value = ( |
| 103 | + mock_batch_job_running |
| 104 | + ) |
| 105 | + mock_genai_client.return_value.batches.get.return_value = ( |
| 106 | + mock_batch_job_succeeded |
| 107 | + ) |
| 108 | + |
| 109 | + response = batchpredict_with_gcs.generate_content( |
| 110 | + output_uri="gs://test-bucket/test-prefix" |
| 111 | + ) |
77 | 112 |
|
78 | | -def test_batch_prediction_with_gcs(gcs_output_uri: str) -> None: |
79 | | - response = batchpredict_with_gcs.generate_content(output_uri=gcs_output_uri) |
| 113 | + mock_genai_client.assert_called_once_with( |
| 114 | + http_options=types.HttpOptions(api_version="v1") |
| 115 | + ) |
| 116 | + mock_genai_client.return_value.batches.create.assert_called_once() |
| 117 | + mock_genai_client.return_value.batches.get.assert_called_once() |
80 | 118 | assert response == JobState.JOB_STATE_SUCCEEDED |
81 | 119 |
|
82 | 120 |
|
83 | 121 | @patch("google.genai.Client") |
84 | 122 | def test_get_batch_job(mock_genai_client: MagicMock) -> None: |
85 | 123 | # Mock the API response |
86 | | - mock_batch_job = types.BatchJob( |
87 | | - name="test-batch-job", |
88 | | - state="JOB_STATE_PENDING" |
89 | | - ) |
| 124 | + mock_batch_job = types.BatchJob(name="test-batch-job", state="JOB_STATE_PENDING") |
90 | 125 |
|
91 | 126 | mock_genai_client.return_value.batches.get.return_value = mock_batch_job |
92 | 127 |
|
93 | 128 | response = get_batch_job.get_batch_job("test-batch-job") |
94 | 129 |
|
95 | | - mock_genai_client.assert_called_once_with(http_options=types.HttpOptions(api_version="v1")) |
| 130 | + mock_genai_client.assert_called_once_with( |
| 131 | + http_options=types.HttpOptions(api_version="v1") |
| 132 | + ) |
96 | 133 | mock_genai_client.return_value.batches.get.assert_called_once() |
97 | 134 | assert response == mock_batch_job |
0 commit comments