
Commit 40151f3

anthropic batch predict samples
1 parent 03318a1 commit 40151f3

File tree

4 files changed: +160 -1 lines changed
Lines changed: 71 additions & 0 deletions
@@ -0,0 +1,71 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

PROJECT_ID = os.getenv("GOOGLE_CLOUD_PROJECT")

output_uri = "bq://storage-samples.generative_ai.gen_ai_batch_prediction.predictions"


def batch_predict_gemini_createjob(output_uri: str) -> str:
    """Submits a batch text prediction job with an Anthropic Claude model and prints the output location."""

    # [START generativeaionvertexai_batch_predict_anthropic_gemini_createjob_bigquery]
    import time
    import vertexai

    from vertexai.batch_prediction import BatchPredictionJob

    # TODO(developer): Update and uncomment the line below
    # PROJECT_ID = "your-project-id"

    # Initialize vertexai
    vertexai.init(project=PROJECT_ID, location="us-east5")

    input_uri = "bq://kellysun-test-project.bp_llm_input.claude_50_requests"

    # Submit a batch prediction job with the Claude model
    batch_prediction_job = BatchPredictionJob.submit(
        source_model="publishers/anthropic/models/claude-3-5-haiku",
        input_dataset=input_uri,
        output_uri_prefix=output_uri,
    )

    # Check job status
    print(f"Job resource name: {batch_prediction_job.resource_name}")
    print(f"Model resource name with the job: {batch_prediction_job.model_name}")
    print(f"Job state: {batch_prediction_job.state.name}")

    # Refresh the job until complete
    while not batch_prediction_job.has_ended:
        time.sleep(5)
        batch_prediction_job.refresh()

    # Check whether the job succeeded
    if batch_prediction_job.has_succeeded:
        print("Job succeeded!")
    else:
        print(f"Job failed: {batch_prediction_job.error}")

    # Check the location of the output
    print(f"Job output location: {batch_prediction_job.output_location}")

    # Example response:
    # Job output location: bq://Project-ID/gen-ai-batch-prediction/predictions-model-year-month-day-hour:minute:second.12345
    # [END generativeaionvertexai_batch_predict_anthropic_gemini_createjob_bigquery]
    return batch_prediction_job


if __name__ == "__main__":
    batch_predict_gemini_createjob(output_uri)
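Note: the BigQuery table referenced by input_uri is not included in this commit. The following is a purely hypothetical sketch of what one input row might carry, assuming the batch job reads an Anthropic Messages payload from a JSON "request" column; the column name, version string, and payload shape are assumptions, not something shown above.

# Hypothetical input row (NOT part of this commit); the "request" column name,
# version string, and payload shape are assumptions about the expected format.
import json

example_row = {
    "request": json.dumps(
        {
            "anthropic_version": "vertex-2023-10-16",  # assumed Vertex AI version string
            "max_tokens": 256,
            "messages": [
                {"role": "user", "content": "Give me a recipe for banana bread."}
            ],
        }
    )
}
print(example_row)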
Lines changed: 74 additions & 0 deletions
@@ -0,0 +1,74 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

PROJECT_ID = os.getenv("GOOGLE_CLOUD_PROJECT")

output_uri = "gs://python-docs-samples-tests"


def batch_predict_createjob(output_uri: str) -> str:
    """Submits a batch text prediction job with an Anthropic Claude model and prints the output location."""

    # [START generativeaionvertexai_batch_predict_anthropic_gemini_createjob]
    import time
    import vertexai

    from vertexai.batch_prediction import BatchPredictionJob

    # TODO(developer): Update and uncomment the line below
    # PROJECT_ID = "your-project-id"

    # Initialize vertexai
    vertexai.init(project=PROJECT_ID, location="us-east5")

    input_uri = "gs://kellysun-test-project-europe-west1/input/claude_varied_input.jsonl"

    # Submit a batch prediction job with the Claude model
    batch_prediction_job = BatchPredictionJob.submit(
        source_model="publishers/anthropic/models/claude-3-5-haiku",
        input_dataset=input_uri,
        output_uri_prefix=output_uri,
    )

    # Check job status
    print(f"Job resource name: {batch_prediction_job.resource_name}")
    print(f"Model resource name with the job: {batch_prediction_job.model_name}")
    print(f"Job state: {batch_prediction_job.state.name}")

    # Refresh the job until complete
    while not batch_prediction_job.has_ended:
        time.sleep(5)
        batch_prediction_job.refresh()

    # Check whether the job succeeded
    if batch_prediction_job.has_succeeded:
        print("Job succeeded!")
    else:
        print(f"Job failed: {batch_prediction_job.error}")

    # Check the location of the output
    print(f"Job output location: {batch_prediction_job.output_location}")

    # Example response:
    # Job output location: gs://your-bucket/gen-ai-batch-prediction/prediction-model-year-month-day-hour:minute:second.12345

    # [END generativeaionvertexai_batch_predict_anthropic_gemini_createjob]
    return batch_prediction_job


if __name__ == "__main__":
    batch_predict_createjob(output_uri)
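Note: the JSONL object in Cloud Storage referenced by input_uri is likewise not part of the commit. Below is a minimal hypothetical helper for producing such a file locally, assuming each line is a JSON object whose "request" field carries the Anthropic Messages payload (the same assumption as in the sketch above).

# Hypothetical helper (NOT part of this commit); the per-line "request" shape
# is an assumption about the expected batch input format.
import json

prompts = [
    "Give me a recipe for banana bread.",
    "Summarize the plot of Hamlet in two sentences.",
]
with open("claude_varied_input.jsonl", "w") as f:
    for prompt in prompts:
        request = {
            "anthropic_version": "vertex-2023-10-16",  # assumed version string
            "max_tokens": 256,
            "messages": [{"role": "user", "content": prompt}],
        }
        f.write(json.dumps({"request": request}) + "\n")

The resulting file would then be copied to the Cloud Storage location used as input_uri (for example with gsutil cp) before the job is submitted.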

generative_ai/model_garden/noxfile_config.py

Lines changed: 1 addition & 1 deletion
@@ -22,7 +22,7 @@

 TEST_CONFIG_OVERRIDE = {
     # You can opt out from the test for specific Python versions.
-    "ignored_versions": ["2.7", "3.7", "3.8", "3.10", "3.11", "3.13"],
+    "ignored_versions": ["2.7", "3.7", "3.8", "3.10", "3.11", "3.12"],
     # Old samples are opted out of enforcing Python type hints
     # All new samples should feature them
     "enforce_type_hints": True,

generative_ai/model_garden/test_model_garden_examples.py

Lines changed: 14 additions & 0 deletions
@@ -16,6 +16,8 @@

 from google.api_core.exceptions import ResourceExhausted

+import claude_3_batch_prediciton_bq
+import claude_3_batch_prediction_gcs
 import claude_3_streaming_example
 import claude_3_tool_example
 import claude_3_unary_example
@@ -40,3 +42,15 @@ def test_tool_use() -> None:
 def test_generate_text() -> None:
     responses = claude_3_unary_example.generate_text()
     assert "bread" in responses.model_dump_json(indent=2)
+
+
+@backoff.on_exception(backoff.expo, ResourceExhausted, max_time=10)
+def test_batch_predict_gcs() -> None:
+    job = claude_3_batch_prediction_gcs.batch_predict_createjob(claude_3_batch_prediction_gcs.output_uri)
+    assert job.has_succeeded
+
+
+@backoff.on_exception(backoff.expo, ResourceExhausted, max_time=10)
+def test_batch_predict_bq() -> None:
+    job = claude_3_batch_prediciton_bq.batch_predict_gemini_createjob(claude_3_batch_prediciton_bq.output_uri)
+    assert job.has_succeeded
