Skip to content

Commit 3df3a3c

Browse files
adding batch gemini predict sample (#12621)
* feat(generativeai): Add batch-predict sample using Gemini model
1 parent 07b1b8b commit 3df3a3c

File tree

2 files changed

+98
-0
lines changed

2 files changed

+98
-0
lines changed
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import os
15+
16+
from vertexai.preview.batch_prediction import BatchPredictionJob
17+
18+
PROJECT_ID = os.getenv("GOOGLE_CLOUD_PROJECT")
19+
20+
21+
def batch_predict_gemini_createjob(input_uri: str, output_uri: str) -> BatchPredictionJob:
22+
"""Perform batch text prediction using a Gemini AI model.
23+
Args:
24+
input_uri (str): URI of the input file in BigQuery table or Google Cloud Storage.
25+
Example: "gs://[BUCKET]/[DATASET].jsonl" OR "bq://[PROJECT].[DATASET].[TABLE]"
26+
27+
output_uri (str): URI of the output folder, in BigQuery table or Google Cloud Storage.
28+
Example: "gs://[BUCKET]/[OUTPUT].jsonl" OR "bq://[PROJECT].[DATASET].[TABLE]"
29+
Returns:
30+
batch_prediction_job: The batch prediction job object containing details of the job.
31+
"""
32+
33+
# [START generativeaionvertexai_batch_predict_gemini_createjob]
34+
import time
35+
import vertexai
36+
37+
from vertexai.preview.batch_prediction import BatchPredictionJob
38+
39+
# TODO(developer): Update and un-comment below lines
40+
# input_uri ="gs://[BUCKET]/[OUTPUT].jsonl" # Example
41+
# output_uri ="gs://[BUCKET]"
42+
43+
# Initialize vertexai
44+
vertexai.init(project=PROJECT_ID, location="us-central1")
45+
46+
# Submit a batch prediction job with Gemini model
47+
batch_prediction_job = BatchPredictionJob.submit(
48+
source_model="gemini-1.5-flash-001",
49+
input_dataset=input_uri,
50+
output_uri_prefix=output_uri
51+
)
52+
53+
# Check job status
54+
print(f"Job resouce name: {batch_prediction_job.resource_name}")
55+
print(f"Model resource name with the job: {batch_prediction_job.model_name}")
56+
print(f"Job state: {batch_prediction_job.state.name}")
57+
58+
# Refresh the job until complete
59+
while not batch_prediction_job.has_ended:
60+
time.sleep(5)
61+
batch_prediction_job.refresh()
62+
63+
# Check if the job succeeds
64+
if batch_prediction_job.has_succeeded:
65+
print("Job succeeded!")
66+
else:
67+
print(f"Job failed: {batch_prediction_job.error}")
68+
69+
# Check the location of the output
70+
print(f"Job output location: {batch_prediction_job.output_location}")
71+
72+
# Example response:
73+
# Job output location: gs://yourbucket/gen-ai-batch-prediction/prediction-model-year-month-day-hour:minute:second.12345
74+
75+
# https://storage.googleapis.com/cloud-samples-data/batch/prompt_for_batch_gemini_predict.jsonl
76+
77+
return batch_prediction_job
78+
79+
# [END generativeaionvertexai_batch_predict_gemini_createjob]
80+
81+
82+
if __name__ == "__main__":
83+
# TODO(developer): Update gsc bucket and file paths
84+
GCS_BUCKET = "gs://yourbucket"
85+
batch_predict_gemini_createjob(f"gs://{GCS_BUCKET}/batch_data/sample_input_file.jsonl",
86+
f"gs://{GCS_BUCKET}/batch_preditions/sample_output/")

generative_ai/batch_predict/test_batch_predict.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,11 @@
1313
# limitations under the License.
1414
from typing import Callable
1515

16+
1617
import batch_code_predict
1718
import batch_text_predict
19+
import gemini_batch_predict
20+
1821

1922
from google.cloud import storage
2023
from google.cloud.aiplatform import BatchPredictionJob
@@ -70,3 +73,12 @@ def test_batch_code_predict(output_folder: pytest.fixture()) -> None:
7073
)
7174
)
7275
assert OUTPUT_PATH in job.output_info.gcs_output_directory
76+
77+
78+
def test_batch_gemini_predict(output_folder: pytest.fixture()) -> None:
79+
input_uri = f"gs://{INPUT_BUCKET}/batch/prompt_for_batch_gemini_predict.jsonl"
80+
job = _main_test(
81+
test_func=lambda: gemini_batch_predict.batch_predict_gemini_createjob(
82+
input_uri, output_folder)
83+
)
84+
assert OUTPUT_PATH in job.output_location

0 commit comments

Comments
 (0)