Skip to content

Commit 23cab80

Browse files
updating testing
1 parent 40151f3 commit 23cab80

File tree

1 file changed

+55
-9
lines changed

1 file changed

+55
-9
lines changed

generative_ai/model_garden/test_model_garden_examples.py

Lines changed: 55 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,17 +11,55 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import os
15+
from typing import Callable
1416

1517
import backoff
16-
1718
from google.api_core.exceptions import ResourceExhausted
19+
from google.cloud import storage
20+
from google.cloud.aiplatform import BatchPredictionJob
21+
from google.cloud.aiplatform_v1 import JobState
22+
import pytest
1823

1924
import claude_3_batch_prediciton_bq
2025
import claude_3_batch_prediction_gcs
2126
import claude_3_streaming_example
2227
import claude_3_tool_example
2328
import claude_3_unary_example
2429

30+
PROJECT_ID = os.getenv("GOOGLE_CLOUD_PROJECT")
31+
32+
INPUT_BUCKET = "kellysun-test-project-europe-west1"
33+
OUTPUT_BUCKET = "python-docs-samples-tests"
34+
OUTPUT_PATH = "batch/batch_text_predict_output"
35+
GCS_OUTPUT_PATH = "gs://python-docs-samples-tests/"
36+
OUTPUT_TABLE = f"bq://{PROJECT_ID}.gen_ai_batch_prediction.predictions"
37+
38+
39+
def _clean_resources() -> None:
40+
storage_client = storage.Client()
41+
bucket = storage_client.get_bucket(OUTPUT_BUCKET)
42+
blobs = bucket.list_blobs(prefix=OUTPUT_PATH)
43+
for blob in blobs:
44+
blob.delete()
45+
46+
47+
@pytest.fixture(scope="session")
48+
def output_folder() -> str:
49+
yield f"gs://{OUTPUT_BUCKET}/{OUTPUT_PATH}"
50+
_clean_resources()
51+
52+
53+
def _main_test(test_func: Callable) -> BatchPredictionJob:
54+
job = None
55+
try:
56+
job = test_func()
57+
assert job.state == JobState.JOB_STATE_SUCCEEDED
58+
return job
59+
finally:
60+
if job is not None:
61+
job.delete()
62+
2563

2664
@backoff.on_exception(backoff.expo, ResourceExhausted, max_time=10)
2765
def test_generate_text_streaming() -> None:
@@ -44,13 +82,21 @@ def test_generate_text() -> None:
4482
assert "bread" in responses.model_dump_json(indent=2)
4583

4684

47-
@backoff.on_exception(backoff.expo, ResourceExhausted, max_time=10)
48-
def test_generate_text_gcs() -> None:
49-
responses = claude_3_batch_prediction_gcs.generate_text()
50-
assert "bread" in responses.model_dump_json(indent=2)
85+
def test_batch_gemini_predict_gcs(output_folder: pytest.fixture()) -> None:
86+
output_uri = "gs://python-docs-samples-tests"
87+
job = _main_test(
88+
test_func=lambda: claude_3_batch_prediction_gcs.batch_predict_gemini_createjob(
89+
output_uri
90+
)
91+
)
92+
assert GCS_OUTPUT_PATH in job.output_location
5193

5294

53-
@backoff.on_exception(backoff.expo, ResourceExhausted, max_time=10)
54-
def test_generate_text_bq() -> None:
55-
responses = claude_3_batch_prediciton_bq.generate_text()
56-
assert "bread" in responses.model_dump_json(indent=2)
95+
def test_batch_gemini_predict_bigquery(output_folder: pytest.fixture()) -> None:
96+
output_uri = f"bq://{PROJECT_ID}.gen_ai_batch_prediction.predictions"
97+
job = _main_test(
98+
test_func=lambda: claude_3_batch_prediciton_bq.batch_predict_gemini_createjob(
99+
output_uri
100+
)
101+
)
102+
assert OUTPUT_TABLE in job.output_location

0 commit comments

Comments
 (0)