| 
5 | 5 |  * you may not use this file except in compliance with the License.  | 
6 | 6 |  * You may obtain a copy of the License at  | 
7 | 7 |  *  | 
8 |  | - *     http://www.apache.org/licenses/LICENSE-2.0  | 
 | 8 | + * http://www.apache.org/licenses/LICENSE-2.0  | 
9 | 9 |  *  | 
10 | 10 |  * Unless required by applicable law or agreed to in writing, software  | 
11 | 11 |  * distributed under the License is distributed on an "AS IS" BASIS,  | 
 | 
16 | 16 | 
 
  | 
17 | 17 | package aiplatform;  | 
18 | 18 | 
 
  | 
19 |  | -// [START aiplatform_batch_text_predict]  | 
 | 19 | +// [START generativeaionvertexai_batch_text_predict]  | 
20 | 20 | 
 
  | 
21 | 21 | import com.google.cloud.aiplatform.v1.BatchPredictionJob;  | 
22 |  | -import com.google.cloud.aiplatform.v1.BatchPredictionJob.InputConfig;  | 
23 |  | -import com.google.cloud.aiplatform.v1.BatchPredictionJob.OutputConfig;  | 
24 | 22 | import com.google.cloud.aiplatform.v1.GcsDestination;  | 
25 | 23 | import com.google.cloud.aiplatform.v1.GcsSource;  | 
26 | 24 | import com.google.cloud.aiplatform.v1.JobServiceClient;  | 
27 | 25 | import com.google.cloud.aiplatform.v1.JobServiceSettings;  | 
28 |  | -import com.google.cloud.aiplatform.v1.LocationName;  | 
 | 26 | +import com.google.gson.Gson;  | 
29 | 27 | import com.google.protobuf.InvalidProtocolBufferException;  | 
30 | 28 | import com.google.protobuf.Value;  | 
31 | 29 | import com.google.protobuf.util.JsonFormat;  | 
32 | 30 | import java.io.IOException;  | 
33 |  | - | 
 | 31 | +import java.util.HashMap;  | 
 | 32 | +import java.util.Map;  | 
 | 33 | +import java.util.concurrent.ExecutionException;  | 
 | 34 | +import java.util.concurrent.TimeoutException;  | 
34 | 35 | 
 
  | 
35 | 36 | public class BatchTextPredictionSample {  | 
36 | 37 | 
 
  | 
37 |  | -  public static void main(String[] args) throws IOException {  | 
38 |  | -    // TODO (Developer): Replace the input_uri and output_uri with your own GCS paths  | 
 | 38 | +  public static void main(String[] args)  | 
 | 39 | +      throws IOException, InterruptedException, ExecutionException, TimeoutException {  | 
 | 40 | +    // TODO(developer): Replace these variables before running the sample.  | 
39 | 41 |     String project = "YOUR_PROJECT_ID";  | 
40 | 42 |     String location = "us-central1";  | 
41 |  | -    // inputUri (str, optional): URI of the input dataset.  | 
 | 43 | +    // inputUri: URI of the input dataset.  | 
42 | 44 |     // Could be a BigQuery table or a Google Cloud Storage file.  | 
43 | 45 |     // E.g. "gs://[BUCKET]/[DATASET].jsonl" OR "bq://[PROJECT].[DATASET].[TABLE]"  | 
44 | 46 |     String inputUri = "gs://cloud-samples-data/batch/prompt_for_batch_text_predict.jsonl";  | 
45 |  | -    // outputUri (str, optional): URI where the output will be stored.  | 
 | 47 | +    // outputUri: URI where the output will be stored.  | 
46 | 48 |     // Could be a BigQuery table or a Google Cloud Storage file.  | 
47 | 49 |     // E.g. "gs://[BUCKET]/[OUTPUT].jsonl" OR "bq://[PROJECT].[DATASET].[TABLE]"  | 
48 |  | -    String outputUri = "gs://batch-bucket-testing/batch_text_predict_output";  | 
49 |  | -    String codeModel = "text-bison";  | 
 | 50 | +    String outputUri = "gs://YOUR_BUCKET/batch_text_predict_output";  | 
 | 51 | +    String textModel = "text-bison";  | 
50 | 52 | 
 
  | 
51 |  | -    batchTextPrediction(project, location, inputUri, outputUri, codeModel);  | 
 | 53 | +    batchTextPrediction(project, inputUri, outputUri, textModel, location);  | 
52 | 54 |   }  | 
53 | 55 | 
 
  | 
54 | 56 |   // Perform batch text prediction using a pre-trained text generation model.  | 
55 | 57 |   // Example of using Google Cloud Storage bucket as the input and output data source  | 
56 |  | -  public static void batchTextPrediction(  | 
57 |  | -      String project, String location, String inputUri,  | 
58 |  | -      String outputUri, String codeModel) throws IOException {  | 
59 |  | -    String endpoint = String.format("%s-aiplatform.googleapis.com:443", location);  | 
60 |  | -    JobServiceSettings jobServiceSettings =  | 
61 |  | -        JobServiceSettings.newBuilder().setEndpoint(endpoint).build();  | 
62 |  | -    // Construct your modelParameters  | 
63 |  | -    String parameters =  | 
64 |  | -        "{\n" + "  \"temperature\": 0.2,\n" + "  \"maxOutputTokens\": 200\n" + "}";  | 
65 |  | -    Value parameterValue = stringToValue(parameters);  | 
 | 58 | +  static BatchPredictionJob batchTextPrediction(  | 
 | 59 | +      String projectId, String inputUri, String outputUri, String textModel, String location)  | 
 | 60 | +      throws IOException {  | 
 | 61 | +    BatchPredictionJob response;  | 
 | 62 | +    JobServiceSettings jobServiceSettings =  JobServiceSettings.newBuilder()  | 
 | 63 | +        .setEndpoint("us-central1-aiplatform.googleapis.com:443").build();  | 
 | 64 | +    String parent = String.format("projects/%s/locations/%s", projectId, location);  | 
66 | 65 |     String modelName = String.format(  | 
67 |  | -        "projects/%s/locations/%s/publishers/google/models/%s", project, location, codeModel);  | 
 | 66 | +        "projects/%s/locations/%s/publishers/google/models/%s", projectId, location, textModel);  | 
 | 67 | +    // Construct model parameters  | 
 | 68 | +    Map<String, String> modelParameters = new HashMap<>();  | 
 | 69 | +    modelParameters.put("maxOutputTokens", "200");  | 
 | 70 | +    modelParameters.put("temperature", "0.2");  | 
 | 71 | +    modelParameters.put("topP", "0.95");  | 
 | 72 | +    modelParameters.put("topK", "40");  | 
 | 73 | +    Value parameterValue = mapToValue(modelParameters);  | 
68 | 74 | 
 
  | 
69 | 75 |     // Initialize client that will be used to send requests. This client only needs to be created  | 
70 | 76 |     // once, and can be reused for multiple requests.  | 
71 | 77 |     try (JobServiceClient jobServiceClient = JobServiceClient.create(jobServiceSettings)) {  | 
72 | 78 | 
 
  | 
73 |  | -      GcsSource.Builder gcsSource = GcsSource.newBuilder();  | 
74 |  | -      gcsSource.addUris(inputUri);  | 
75 |  | -      InputConfig inputConfig =  | 
76 |  | -          InputConfig.newBuilder()  | 
77 |  | -              .setGcsSource(gcsSource)  | 
78 |  | -              .setInstancesFormat("jsonl")  | 
79 |  | -              .build();  | 
80 |  | - | 
81 |  | -      GcsDestination.Builder gcsDestination = GcsDestination.newBuilder();  | 
82 |  | -      gcsDestination.setOutputUriPrefix(outputUri);  | 
83 |  | -      OutputConfig outputConfig =  | 
84 |  | -          OutputConfig.newBuilder()  | 
85 |  | -              .setGcsDestination(gcsDestination)  | 
86 |  | -              .setPredictionsFormat("jsonl")  | 
87 |  | -              .build();  | 
88 |  | - | 
89 |  | -      BatchPredictionJob.Builder batchPredictionJob =  | 
 | 79 | +      BatchPredictionJob batchPredictionJob =  | 
90 | 80 |           BatchPredictionJob.newBuilder()  | 
91 | 81 |               .setDisplayName("my batch text prediction job " + System.currentTimeMillis())  | 
92 | 82 |               .setModel(modelName)  | 
93 |  | -              .setInputConfig(inputConfig)  | 
94 |  | -              .setOutputConfig(outputConfig)  | 
95 |  | -              .setModelParameters(parameterValue);  | 
 | 83 | +              .setInputConfig(  | 
 | 84 | +                  BatchPredictionJob.InputConfig.newBuilder()  | 
 | 85 | +                      .setGcsSource(GcsSource.newBuilder().addUris(inputUri).build())  | 
 | 86 | +                      .setInstancesFormat("jsonl")  | 
 | 87 | +                      .build())  | 
 | 88 | +              .setOutputConfig(  | 
 | 89 | +                  BatchPredictionJob.OutputConfig.newBuilder()  | 
 | 90 | +                      .setGcsDestination(GcsDestination.newBuilder()  | 
 | 91 | +                          .setOutputUriPrefix(outputUri).build())  | 
 | 92 | +                      .setPredictionsFormat("jsonl")  | 
 | 93 | +                      .build())  | 
 | 94 | +              .setModelParameters(parameterValue)  | 
 | 95 | +              .build();  | 
96 | 96 | 
 
  | 
97 |  | -      LocationName parent = LocationName.of(project, location);  | 
98 |  | -      BatchPredictionJob response =  | 
99 |  | -          jobServiceClient.createBatchPredictionJob(parent, batchPredictionJob.build());  | 
 | 97 | +      // Create the batch prediction job  | 
 | 98 | +      response =  | 
 | 99 | +          jobServiceClient.createBatchPredictionJob(parent, batchPredictionJob);  | 
100 | 100 | 
 
  | 
101 | 101 |       System.out.format("response: %s\n", response);  | 
102 | 102 |       System.out.format("\tName: %s\n", response.getName());  | 
103 | 103 |     }  | 
 | 104 | +    return response;  | 
104 | 105 |   }  | 
105 | 106 | 
 
  | 
106 |  | -  // Convert a Json string to a protobuf.Value  | 
107 |  | -  static Value stringToValue(String value) throws InvalidProtocolBufferException {  | 
 | 107 | +  private static Value mapToValue(Map<String, String> map) throws InvalidProtocolBufferException {  | 
 | 108 | +    Gson gson = new Gson();  | 
 | 109 | +    String json = gson.toJson(map);  | 
108 | 110 |     Value.Builder builder = Value.newBuilder();  | 
109 |  | -    JsonFormat.parser().merge(value, builder);  | 
 | 111 | +    JsonFormat.parser().merge(json, builder);  | 
110 | 112 |     return builder.build();  | 
111 | 113 |   }  | 
112 | 114 | }  | 
113 |  | -// [END aiplatform_batch_text_predict]  | 
 | 115 | +// [END generativeaionvertexai_batch_text_predict]  | 
0 commit comments