1111# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212# See the License for the specific language governing permissions and
1313# limitations under the License.
14+ import os
15+ from typing import Callable
1416
1517import backoff
16-
1718from google .api_core .exceptions import ResourceExhausted
19+ from google .cloud import storage
20+ from google .cloud .aiplatform import BatchPredictionJob
21+ from google .cloud .aiplatform_v1 import JobState
22+ import pytest
1823
1924import claude_3_batch_prediciton_bq
2025import claude_3_batch_prediction_gcs
2126import claude_3_streaming_example
2227import claude_3_tool_example
2328import claude_3_unary_example
2429
30+ PROJECT_ID = os .getenv ("GOOGLE_CLOUD_PROJECT" )
31+
32+ INPUT_BUCKET = "kellysun-test-project-europe-west1"
33+ OUTPUT_BUCKET = "python-docs-samples-tests"
34+ OUTPUT_PATH = "batch/batch_text_predict_output"
35+ GCS_OUTPUT_PATH = "gs://python-docs-samples-tests/"
36+ OUTPUT_TABLE = f"bq://{ PROJECT_ID } .gen_ai_batch_prediction.predictions"
37+
38+
39+ def _clean_resources () -> None :
40+ storage_client = storage .Client ()
41+ bucket = storage_client .get_bucket (OUTPUT_BUCKET )
42+ blobs = bucket .list_blobs (prefix = OUTPUT_PATH )
43+ for blob in blobs :
44+ blob .delete ()
45+
46+
47+ @pytest .fixture (scope = "session" )
48+ def output_folder () -> str :
49+ yield f"gs://{ OUTPUT_BUCKET } /{ OUTPUT_PATH } "
50+ _clean_resources ()
51+
52+
53+ def _main_test (test_func : Callable ) -> BatchPredictionJob :
54+ job = None
55+ try :
56+ job = test_func ()
57+ assert job .state == JobState .JOB_STATE_SUCCEEDED
58+ return job
59+ finally :
60+ if job is not None :
61+ job .delete ()
62+
2563
2664@backoff .on_exception (backoff .expo , ResourceExhausted , max_time = 10 )
2765def test_generate_text_streaming () -> None :
@@ -44,13 +82,21 @@ def test_generate_text() -> None:
4482 assert "bread" in responses .model_dump_json (indent = 2 )
4583
4684
47- @backoff .on_exception (backoff .expo , ResourceExhausted , max_time = 10 )
48- def test_generate_text_gcs () -> None :
49- responses = claude_3_batch_prediction_gcs .generate_text ()
50- assert "bread" in responses .model_dump_json (indent = 2 )
85+ def test_batch_gemini_predict_gcs (output_folder : pytest .fixture ()) -> None :
86+ output_uri = "gs://python-docs-samples-tests"
87+ job = _main_test (
88+ test_func = lambda : claude_3_batch_prediction_gcs .batch_predict_gemini_createjob (
89+ output_uri
90+ )
91+ )
92+ assert GCS_OUTPUT_PATH in job .output_location
5193
5294
53- @backoff .on_exception (backoff .expo , ResourceExhausted , max_time = 10 )
54- def test_generate_text_bq () -> None :
55- responses = claude_3_batch_prediciton_bq .generate_text ()
56- assert "bread" in responses .model_dump_json (indent = 2 )
95+ def test_batch_gemini_predict_bigquery (output_folder : pytest .fixture ()) -> None :
96+ output_uri = f"bq://{ PROJECT_ID } .gen_ai_batch_prediction.predictions"
97+ job = _main_test (
98+ test_func = lambda : claude_3_batch_prediciton_bq .batch_predict_gemini_createjob (
99+ output_uri
100+ )
101+ )
102+ assert OUTPUT_TABLE in job .output_location
0 commit comments