Skip to content

Commit 3144e93

Browse files
irataxy and Sita04 authored
feat: add Gemini batch prediction samples for GCS and BQ (#9592)
* feat: add Gemini batch prediction samples for GCS and BQ
* fix region tag to align with Python sample
* Update aiplatform/src/main/java/aiplatform/batchpredict/CreateBatchPredictionGeminiBigqueryJobSample.java
  Co-authored-by: Sita Lakshmi Sangameswaran <[email protected]>
* Update aiplatform/src/main/java/aiplatform/batchpredict/CreateBatchPredictionGeminiBigqueryJobSample.java
  Co-authored-by: Sita Lakshmi Sangameswaran <[email protected]>
* Update aiplatform/src/main/java/aiplatform/batchpredict/CreateBatchPredictionGeminiBigqueryJobSample.java
  Co-authored-by: Sita Lakshmi Sangameswaran <[email protected]>
* address feedback
* fix lint and add example responses
* address feedback
---------
Co-authored-by: Sita Lakshmi Sangameswaran <[email protected]>
1 parent 721bac6 commit 3144e93

File tree

3 files changed

+321
-0
lines changed

3 files changed

+321
-0
lines changed
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
/*
2+
* Copyright 2024 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package aiplatform.batchpredict;
18+
19+
// [START generativeaionvertexai_batch_predict_gemini_createjob_bigquery]
20+
import com.google.cloud.aiplatform.v1.BatchPredictionJob;
21+
import com.google.cloud.aiplatform.v1.BigQueryDestination;
22+
import com.google.cloud.aiplatform.v1.BigQuerySource;
23+
import com.google.cloud.aiplatform.v1.JobServiceClient;
24+
import com.google.cloud.aiplatform.v1.JobServiceSettings;
25+
import com.google.cloud.aiplatform.v1.LocationName;
26+
import java.io.IOException;
27+
28+
public class CreateBatchPredictionGeminiBigqueryJobSample {
29+
30+
public static void main(String[] args) throws IOException {
31+
// TODO(developer): Update these variables before running the sample.
32+
String project = "PROJECT_ID";
33+
String bigqueryDestinationOutputUri = "bq://PROJECT_ID.MY_DATASET.MY_TABLE";
34+
35+
createBatchPredictionGeminiBigqueryJobSample(project, bigqueryDestinationOutputUri);
36+
}
37+
38+
// Create a batch prediction job using BigQuery input and output datasets.
39+
public static BatchPredictionJob createBatchPredictionGeminiBigqueryJobSample(
40+
String project, String bigqueryDestinationOutputUri) throws IOException {
41+
String location = "us-central1";
42+
JobServiceSettings settings =
43+
JobServiceSettings.newBuilder()
44+
.setEndpoint(String.format("%s-aiplatform.googleapis.com:443", location))
45+
.build();
46+
47+
// Initialize client that will be used to send requests. This client only needs to be created
48+
// once, and can be reused for multiple requests.
49+
try (JobServiceClient client = JobServiceClient.create(settings)) {
50+
BigQuerySource bigquerySource =
51+
BigQuerySource.newBuilder()
52+
.setInputUri("bq://storage-samples.generative_ai.batch_requests_for_multimodal_input")
53+
.build();
54+
BatchPredictionJob.InputConfig inputConfig =
55+
BatchPredictionJob.InputConfig.newBuilder()
56+
.setInstancesFormat("bigquery")
57+
.setBigquerySource(bigquerySource)
58+
.build();
59+
BigQueryDestination bigqueryDestination =
60+
BigQueryDestination.newBuilder().setOutputUri(bigqueryDestinationOutputUri).build();
61+
BatchPredictionJob.OutputConfig outputConfig =
62+
BatchPredictionJob.OutputConfig.newBuilder()
63+
.setPredictionsFormat("bigquery")
64+
.setBigqueryDestination(bigqueryDestination)
65+
.build();
66+
String modelName =
67+
String.format(
68+
"projects/%s/locations/%s/publishers/google/models/%s",
69+
project, location, "gemini-1.5-flash-002");
70+
71+
BatchPredictionJob batchPredictionJob =
72+
BatchPredictionJob.newBuilder()
73+
.setDisplayName("my-display-name")
74+
.setModel(modelName) // Add model parameters per request in the input BigQuery table.
75+
.setInputConfig(inputConfig)
76+
.setOutputConfig(outputConfig)
77+
.build();
78+
79+
LocationName parent = LocationName.of(project, location);
80+
BatchPredictionJob response = client.createBatchPredictionJob(parent, batchPredictionJob);
81+
System.out.format("\tName: %s\n", response.getName());
82+
// Example response:
83+
// Name: projects/<project>/locations/us-central1/batchPredictionJobs/<job-id>
84+
return response;
85+
}
86+
}
87+
}
88+
89+
// [END generativeaionvertexai_batch_predict_gemini_createjob_bigquery]
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
/*
2+
* Copyright 2024 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package aiplatform.batchpredict;
18+
19+
// [START generativeaionvertexai_batch_predict_gemini_createjob_gcs]
20+
import com.google.cloud.aiplatform.v1.BatchPredictionJob;
21+
import com.google.cloud.aiplatform.v1.GcsDestination;
22+
import com.google.cloud.aiplatform.v1.GcsSource;
23+
import com.google.cloud.aiplatform.v1.JobServiceClient;
24+
import com.google.cloud.aiplatform.v1.JobServiceSettings;
25+
import com.google.cloud.aiplatform.v1.LocationName;
26+
import java.io.IOException;
27+
28+
public class CreateBatchPredictionGeminiJobSample {
29+
30+
public static void main(String[] args) throws IOException {
31+
// TODO(developer): Update these variables before running the sample.
32+
String project = "PROJECT_ID";
33+
String gcsDestinationOutputUriPrefix = "gs://MY_BUCKET/";
34+
35+
createBatchPredictionGeminiJobSample(project, gcsDestinationOutputUriPrefix);
36+
}
37+
38+
// Create a batch prediction job using a JSONL input file and output URI, both in Cloud
39+
// Storage.
40+
public static BatchPredictionJob createBatchPredictionGeminiJobSample(
41+
String project, String gcsDestinationOutputUriPrefix) throws IOException {
42+
String location = "us-central1";
43+
JobServiceSettings settings =
44+
JobServiceSettings.newBuilder()
45+
.setEndpoint(String.format("%s-aiplatform.googleapis.com:443", location))
46+
.build();
47+
48+
// Initialize client that will be used to send requests. This client only needs to be created
49+
// once, and can be reused for multiple requests.
50+
try (JobServiceClient client = JobServiceClient.create(settings)) {
51+
GcsSource gcsSource =
52+
GcsSource.newBuilder()
53+
.addUris(
54+
"gs://cloud-samples-data/generative-ai/batch/"
55+
+ "batch_requests_for_multimodal_input.jsonl")
56+
// Or try
57+
// "gs://cloud-samples-data/generative-ai/batch/gemini_multimodal_batch_predict.jsonl"
58+
// for a batch prediction that uses audio, video, and an image.
59+
.build();
60+
BatchPredictionJob.InputConfig inputConfig =
61+
BatchPredictionJob.InputConfig.newBuilder()
62+
.setInstancesFormat("jsonl")
63+
.setGcsSource(gcsSource)
64+
.build();
65+
GcsDestination gcsDestination =
66+
GcsDestination.newBuilder().setOutputUriPrefix(gcsDestinationOutputUriPrefix).build();
67+
BatchPredictionJob.OutputConfig outputConfig =
68+
BatchPredictionJob.OutputConfig.newBuilder()
69+
.setPredictionsFormat("jsonl")
70+
.setGcsDestination(gcsDestination)
71+
.build();
72+
String modelName =
73+
String.format(
74+
"projects/%s/locations/%s/publishers/google/models/%s",
75+
project, location, "gemini-1.5-flash-002");
76+
77+
BatchPredictionJob batchPredictionJob =
78+
BatchPredictionJob.newBuilder()
79+
.setDisplayName("my-display-name")
80+
.setModel(modelName) // Add model parameters per request in the input jsonl file.
81+
.setInputConfig(inputConfig)
82+
.setOutputConfig(outputConfig)
83+
.build();
84+
85+
LocationName parent = LocationName.of(project, location);
86+
BatchPredictionJob response = client.createBatchPredictionJob(parent, batchPredictionJob);
87+
System.out.format("\tName: %s\n", response.getName());
88+
// Example response:
89+
// Name: projects/<project>/locations/us-central1/batchPredictionJobs/<job-id>
90+
return response;
91+
}
92+
}
93+
}
94+
95+
// [END generativeaionvertexai_batch_predict_gemini_createjob_gcs]
Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
/*
2+
* Copyright 2024 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package aiplatform;
18+
19+
import static junit.framework.TestCase.assertNotNull;
20+
import static org.hamcrest.CoreMatchers.containsString;
21+
import static org.hamcrest.MatcherAssert.assertThat;
22+
23+
import aiplatform.batchpredict.CreateBatchPredictionGeminiBigqueryJobSample;
24+
import aiplatform.batchpredict.CreateBatchPredictionGeminiJobSample;
25+
import com.google.cloud.aiplatform.v1.BatchPredictionJob;
26+
import java.io.ByteArrayOutputStream;
27+
import java.io.IOException;
28+
import java.io.PrintStream;
29+
import java.time.Instant;
30+
import java.util.concurrent.ExecutionException;
31+
import java.util.concurrent.TimeUnit;
32+
import java.util.concurrent.TimeoutException;
33+
import org.junit.AfterClass;
34+
import org.junit.BeforeClass;
35+
import org.junit.Test;
36+
import org.junit.runner.RunWith;
37+
import org.junit.runners.JUnit4;
38+
39+
@RunWith(JUnit4.class)
40+
public class CreateBatchPredictionGeminiJobSampleTest {
41+
42+
private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID");
43+
private static final String GCS_OUTPUT_URI = "gs://ucaip-samples-test-output/";
44+
private static final String now = String.valueOf(Instant.now().getEpochSecond());
45+
private static final String BIGQUERY_DESTINATION_OUTPUT_URI_PREFIX =
46+
String.format("bq://%s.gen_ai_batch_prediction.predictions_%s", PROJECT, now);
47+
48+
private static ByteArrayOutputStream bout;
49+
private static PrintStream originalPrintStream;
50+
private static String batchPredictionGcsJobId;
51+
private static String batchPredictionBqJobId;
52+
53+
private static void requireEnvVar(String varName) {
54+
String errorMessage =
55+
String.format("Environment variable '%s' is required to perform these tests.", varName);
56+
assertNotNull(errorMessage, System.getenv(varName));
57+
}
58+
59+
@BeforeClass
60+
public static void checkRequirements() {
61+
requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS");
62+
requireEnvVar("UCAIP_PROJECT_ID");
63+
}
64+
65+
@AfterClass
66+
public static void tearDown()
67+
throws InterruptedException, ExecutionException, IOException, TimeoutException {
68+
// Set up
69+
bout = new ByteArrayOutputStream();
70+
PrintStream out = new PrintStream(bout);
71+
originalPrintStream = System.out;
72+
System.setOut(out);
73+
74+
// Cloud Storage job
75+
CancelBatchPredictionJobSample.cancelBatchPredictionJobSample(PROJECT, batchPredictionGcsJobId);
76+
77+
// Assert
78+
String cancelResponse = bout.toString();
79+
assertThat(cancelResponse, containsString("Cancelled the Batch Prediction Job"));
80+
TimeUnit.MINUTES.sleep(2);
81+
82+
// Delete the Batch Prediction Job
83+
DeleteBatchPredictionJobSample.deleteBatchPredictionJobSample(PROJECT, batchPredictionGcsJobId);
84+
85+
// Assert
86+
String deleteResponse = bout.toString();
87+
assertThat(deleteResponse, containsString("Deleted Batch"));
88+
89+
// BigQuery job
90+
CancelBatchPredictionJobSample.cancelBatchPredictionJobSample(PROJECT, batchPredictionBqJobId);
91+
92+
// Assert
93+
cancelResponse = bout.toString();
94+
assertThat(cancelResponse, containsString("Cancelled the Batch Prediction Job"));
95+
TimeUnit.MINUTES.sleep(2);
96+
97+
// Delete the Batch Prediction Job
98+
DeleteBatchPredictionJobSample.deleteBatchPredictionJobSample(PROJECT, batchPredictionBqJobId);
99+
100+
// Assert
101+
deleteResponse = bout.toString();
102+
assertThat(deleteResponse, containsString("Deleted Batch"));
103+
104+
System.out.flush();
105+
System.setOut(originalPrintStream);
106+
}
107+
108+
@Test
109+
public void testCreateBatchPredictionGeminiJobSampleTest() throws IOException {
110+
// Cloud Storage job
111+
// Act
112+
BatchPredictionJob job =
113+
CreateBatchPredictionGeminiJobSample.createBatchPredictionGeminiJobSample(
114+
PROJECT, GCS_OUTPUT_URI);
115+
116+
// Assert
117+
assertThat(job.getName(), containsString("batchPredictionJobs"));
118+
119+
String[] id = job.getName().split("/");
120+
batchPredictionGcsJobId = id[id.length - 1];
121+
}
122+
123+
@Test
124+
public void testCreateBatchPredictionGeminiBigqueryJobSampleTest() throws IOException {
125+
// BigQuery job
126+
// Act
127+
BatchPredictionJob job =
128+
CreateBatchPredictionGeminiBigqueryJobSample.createBatchPredictionGeminiBigqueryJobSample(
129+
PROJECT, BIGQUERY_DESTINATION_OUTPUT_URI_PREFIX);
130+
131+
// Assert
132+
assertThat(job.getName(), containsString("batchPredictionJobs"));
133+
134+
String[] id = job.getName().split("/");
135+
batchPredictionBqJobId = id[id.length - 1];
136+
}
137+
}

0 commit comments

Comments
 (0)