Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 91 additions & 0 deletions generative_ai/batch_predict/gemini_batch_predict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from google.cloud.aiplatform import BatchPredictionJob


# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from google.cloud.aiplatform import BatchPredictionJob

import os


PROJECT_ID = os.getenv("GOOGLE_CLOUD_PROJECT")
# LOCATION = "us-central1"


def batch_prediction_gemini_example(
    input_uri: str = None, output_uri: str = None
) -> BatchPredictionJob:
    """Perform batch text prediction using a Gemini AI model.

    Args:
        input_uri (str, optional): URI of the input dataset. Could be a BigQuery
            table or a Google Cloud Storage file.
            E.g. "gs://[BUCKET]/[DATASET].jsonl" OR "bq://[PROJECT].[DATASET].[TABLE]"
        output_uri (str, optional): URI where the output will be stored.
            Could be a BigQuery table or a Google Cloud Storage file.
            E.g. "gs://[BUCKET]/[OUTPUT].jsonl" OR "bq://[PROJECT].[DATASET].[TABLE]"
            When None, the service picks a default output location.

    Returns:
        BatchPredictionJob: The batch prediction job object containing details of the job.
    """
    # [START generativeaionvertexai_batch_predict_gemini_createjob]
    import time

    import vertexai

    from vertexai.preview.batch_prediction import BatchPredictionJob

    # TODO(developer): Update and un-comment below lines
    # PROJECT_ID = "example_project"
    # input_uri = "bq://example_project.example_dataset.example_table"

    # Initialize vertexai
    vertexai.init(project=PROJECT_ID, location="us-central1")

    # Submit a batch prediction job with Gemini model
    job = BatchPredictionJob.submit(
        source_model="gemini-1.5-flash-001",
        input_dataset=input_uri,
        # None lets the service choose its default output location.
        output_uri_prefix=output_uri,
    )

    # Check job status
    print(f"Job resource name: {job.resource_name}")
    print(f"Model resource name with the job: {job.model_name}")
    print(f"Job state: {job.state.name}")

    # Refresh the job until complete
    while not job.has_ended:
        time.sleep(5)
        job.refresh()

    # Check if the job succeeds
    if job.has_succeeded:
        print("Job succeeded!")
    else:
        print(f"Job failed: {job.error}")

    # Check the location of the output
    print(f"Job output location: {job.output_location}")
    # Example response:
    # bq://example_project.gen_ai_batch_prediction.predictions_2024-09-17-11-09-45-ABC123
    # [END generativeaionvertexai_batch_predict_gemini_createjob]
    return job


if __name__ == "__main__":
    # Allow running the sample as a script with module-level defaults
    # (PROJECT_ID is read from the GOOGLE_CLOUD_PROJECT env var above);
    # presumably requires valid Google Cloud credentials — not shown here.
    batch_prediction_gemini_example()
9 changes: 9 additions & 0 deletions generative_ai/batch_predict/test_batch_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import batch_code_predict
import batch_text_predict
import batch_prediction_gemini_example

from google.cloud import storage
from google.cloud.aiplatform import BatchPredictionJob
Expand Down Expand Up @@ -70,3 +71,11 @@ def test_batch_code_predict(output_folder: pytest.fixture()) -> None:
)
)
assert OUTPUT_PATH in job.output_info.gcs_output_directory


def test_batch_prediction_gemini_example(output_folder: pytest.fixture()) -> None:
    """Smoke-test the Gemini batch prediction sample against a GCS input file.

    Asserts that the finished job reports a GCS output directory.
    """
    input_uri = f"gs://{INPUT_BUCKET}/batch/prompt_for_batch_code_predict.jsonl"
    # The sample function is named `batch_prediction_gemini_example`
    # (the previous call used a swapped name and would raise AttributeError).
    job = _main_test(
        test_func=lambda: batch_prediction_gemini_example.batch_prediction_gemini_example(
            input_uri
        )
    )
    assert job.output_info.gcs_output_directory
Loading