Skip to content

Commit 65bb552

Browse files
feat: batch predict with GCS
1 parent 0e0e640 commit 65bb552

File tree

4 files changed

+121
-0
lines changed

4 files changed

+121
-0
lines changed
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from google.genai.types import BatchJob
16+
17+
18+
def create_job(output_uri: str) -> BatchJob:
19+
# [START googlegenaisdk_batch_prediction_with_gcs]
20+
import time
21+
22+
from google import genai
23+
24+
client = genai.Client()
25+
# TODO(developer): Update and un-comment below line
26+
# output_uri = "gs://your-bucket/your-prefix/..."
27+
28+
job = client.batches.create(
29+
model="gemini-2.0-flash-001",
30+
src="gs://cloud-samples-data/batch/prompt_for_batch_gemini_predict.jsonl",
31+
config={
32+
"dest": output_uri
33+
}
34+
)
35+
print(f"Job name: {job.name}")
36+
print(f"Job state: {job.state}")
37+
# Example response:
38+
# Job name: projects/%PROJECT_ID%/locations/us-central1/batchPredictionJobs/9876453210000000000
39+
# Job state: JOB_STATE_PENDING
40+
41+
# See the documentation: https://googleapis.github.io/python-genai/genai.html#genai.types.BatchJob
42+
completed_states = [
43+
"JOB_STATE_SUCCEEDED",
44+
"JOB_STATE_FAILED",
45+
"JOB_STATE_CANCELLED",
46+
"JOB_STATE_PAUSED",
47+
]
48+
while job.state not in completed_states:
49+
time.sleep(30)
50+
job = client.batches.get(name=job.name)
51+
print(f"Job state: {job.state}")
52+
# Example response:
53+
# Job state: JOB_STATE_PENDING
54+
# Job state: JOB_STATE_RUNNING
55+
# Job state: JOB_STATE_RUNNING
56+
# ...
57+
# Job state: JOB_STATE_SUCCEEDED
58+
59+
# [END googlegenaisdk_batch_prediction_with_gcs]
60+
return job
61+
62+
63+
if __name__ == "__main__":
64+
create_job(
65+
output_uri="gs://your-bucket/your-prefix/..."
66+
)
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
google-api-core==2.24.0
2+
google-cloud-storage==2.19.0
3+
pytest==8.2.0
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
google-genai==0.7.0
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from datetime import datetime as dt, UTC
16+
import os
17+
18+
from google.cloud import storage
19+
from google.genai.types import JobState
20+
import pytest
21+
22+
import batch_prediction_with_gcs
23+
24+
25+
os.environ["GOOGLE_GENAI_USE_VERTEXAI"] = "True"
26+
os.environ["GOOGLE_CLOUD_LOCATION"] = "us-central1"
27+
# The project name is included in the CICD pipeline
28+
# os.environ['GOOGLE_CLOUD_PROJECT'] = "add-your-project-name"
29+
GCS_OUTPUT_BUCKET = "python-docs-samples-tests"
30+
GCS_OUTPUT_BUCKET = "gemini-batch-prediction-results"
31+
32+
33+
@pytest.fixture(scope="session")
34+
def gcs_output_uri():
35+
prefix = f"text_output/{dt.now(UTC)}"
36+
37+
yield f"gs://{GCS_OUTPUT_BUCKET}/{prefix}"
38+
39+
storage_client = storage.Client()
40+
bucket = storage_client.get_bucket(GCS_OUTPUT_BUCKET)
41+
blobs = bucket.list_blobs(prefix=prefix)
42+
for blob in blobs:
43+
blob.delete()
44+
45+
46+
def test_batch_prediction_with_gcs(gcs_output_uri) -> None:
47+
job = batch_prediction_with_gcs.create_job(output_uri=gcs_output_uri)
48+
assert job
49+
assert job.state == "JOB_STATE_SUCCEEDED"
50+
assert job.dest.gcs_uri == gcs_output_uri
51+
assert job.state == JobState.JOB_STATE_SUCCEEDED

0 commit comments

Comments
 (0)