Skip to content

Commit 567844f

Browse files
jdomingrJuan Dominguez
andauthored
feat(genai): add batch prediction samples (2) (#10190)
* feat(genai): add batch prediction with bq sample * refactor(genai): change polling logic and update tests * fix(genai): fix test errors --------- Co-authored-by: Juan Dominguez <[email protected]>
1 parent ac89f31 commit 567844f

File tree

2 files changed

+131
-0
lines changed

2 files changed

+131
-0
lines changed
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
/*
2+
* Copyright 2025 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package genai.batchprediction;
18+
19+
// [START googlegenaisdk_batchpredict_with_bq]
20+
21+
import static com.google.genai.types.JobState.Known.JOB_STATE_CANCELLED;
22+
import static com.google.genai.types.JobState.Known.JOB_STATE_FAILED;
23+
import static com.google.genai.types.JobState.Known.JOB_STATE_PAUSED;
24+
import static com.google.genai.types.JobState.Known.JOB_STATE_SUCCEEDED;
25+
26+
import com.google.genai.Client;
27+
import com.google.genai.types.BatchJob;
28+
import com.google.genai.types.BatchJobDestination;
29+
import com.google.genai.types.BatchJobSource;
30+
import com.google.genai.types.CreateBatchJobConfig;
31+
import com.google.genai.types.GetBatchJobConfig;
32+
import com.google.genai.types.HttpOptions;
33+
import com.google.genai.types.JobState;
34+
import java.util.EnumSet;
35+
import java.util.Set;
36+
import java.util.concurrent.TimeUnit;
37+
38+
public class BatchPredictionWithBq {
39+
40+
public static void main(String[] args) throws InterruptedException {
41+
// TODO(developer): Replace these variables before running the sample.
42+
// To use a tuned model, set the model param to your tuned model using the following format:
43+
// modelId = "projects/{PROJECT_ID}/locations/{LOCATION}/models/{MODEL_ID}
44+
String modelId = "gemini-2.5-flash";
45+
String outputUri = "bq://your-project.your_dataset.your_table";
46+
createBatchJob(modelId, outputUri);
47+
}
48+
49+
// Creates a batch prediction job with Google BigQuery.
50+
public static JobState createBatchJob(String modelId, String outputUri)
51+
throws InterruptedException {
52+
// Client Initialization. Once created, it can be reused for multiple requests.
53+
try (Client client =
54+
Client.builder()
55+
.location("us-central1")
56+
.vertexAI(true)
57+
.httpOptions(HttpOptions.builder().apiVersion("v1").build())
58+
.build()) {
59+
60+
// See the documentation:
61+
// https://googleapis.github.io/java-genai/javadoc/com/google/genai/Batches.html
62+
BatchJobSource batchJobSource =
63+
BatchJobSource.builder()
64+
.bigqueryUri("bq://storage-samples.generative_ai.batch_requests_for_multimodal_input")
65+
.format("bigquery")
66+
.build();
67+
68+
CreateBatchJobConfig batchJobConfig =
69+
CreateBatchJobConfig.builder()
70+
.displayName("your-display-name")
71+
.dest(BatchJobDestination.builder().bigqueryUri(outputUri).format("bigquery").build())
72+
.build();
73+
74+
BatchJob batchJob = client.batches.create(modelId, batchJobSource, batchJobConfig);
75+
76+
String jobName =
77+
batchJob.name().orElseThrow(() -> new IllegalStateException("Missing job name"));
78+
JobState jobState =
79+
batchJob.state().orElseThrow(() -> new IllegalStateException("Missing job state"));
80+
System.out.println("Job name: " + jobName);
81+
System.out.println("Job state: " + jobState);
82+
// Job name:
83+
// projects/.../locations/.../batchPredictionJobs/3189981423167602688
84+
// Job state: JOB_STATE_PENDING
85+
86+
// See the documentation:
87+
// https://googleapis.github.io/java-genai/javadoc/com/google/genai/types/BatchJob.html
88+
Set<JobState.Known> completedStates =
89+
EnumSet.of(JOB_STATE_SUCCEEDED, JOB_STATE_FAILED, JOB_STATE_CANCELLED, JOB_STATE_PAUSED);
90+
91+
while (!completedStates.contains(jobState.knownEnum())) {
92+
TimeUnit.SECONDS.sleep(30);
93+
batchJob = client.batches.get(jobName, GetBatchJobConfig.builder().build());
94+
jobState =
95+
batchJob
96+
.state()
97+
.orElseThrow(() -> new IllegalStateException("Missing job state during polling"));
98+
System.out.println("Job state: " + jobState);
99+
}
100+
// Example response:
101+
// Job state: JOB_STATE_QUEUED
102+
// Job state: JOB_STATE_RUNNING
103+
// Job state: JOB_STATE_RUNNING
104+
// ...
105+
// Job state: JOB_STATE_SUCCEEDED
106+
return jobState;
107+
}
108+
}
109+
}
110+
// [END googlegenaisdk_batchpredict_with_bq]

genai/snippets/src/test/java/genai/batchprediction/BatchPredictionIT.java

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,27 @@ public void testBatchPredictionWithGcs() throws InterruptedException {
138138
assertThat(output).contains("Job state: JOB_STATE_SUCCEEDED");
139139
}
140140

141+
@Test
142+
public void testBatchPredictionWithBq() throws InterruptedException {
143+
// Act
144+
String outputBqUri = "bq://test-project.test_dataset.test_table";
145+
JobState response = BatchPredictionWithBq.createBatchJob(GEMINI_FLASH, outputBqUri);
146+
147+
// Assert
148+
verify(mockedBatches, times(1))
149+
.create(anyString(), any(BatchJobSource.class), any(CreateBatchJobConfig.class));
150+
verify(mockedBatches, times(2)).get(anyString(), any(GetBatchJobConfig.class));
151+
152+
assertThat(response).isNotNull();
153+
assertThat(response.knownEnum()).isEqualTo(JOB_STATE_SUCCEEDED);
154+
155+
String output = bout.toString();
156+
assertThat(output).contains("Job name: " + jobName);
157+
assertThat(output).contains("Job state: JOB_STATE_PENDING");
158+
assertThat(output).contains("Job state: JOB_STATE_RUNNING");
159+
assertThat(output).contains("Job state: JOB_STATE_SUCCEEDED");
160+
}
161+
141162
@Test
142163
public void testBatchPredictionEmbeddingsWithGcs() throws InterruptedException {
143164
// Act

0 commit comments

Comments
 (0)