Skip to content

Commit 4d7c2a6

Browse files
Fixed test
2 parents 90bbb41 + c617f2a commit 4d7c2a6

27 files changed

+1372
-219
lines changed

aiplatform/pom.xml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,12 @@
8989
<version>1.7.1</version>
9090
<scope>test</scope>
9191
</dependency>
92+
<dependency>
93+
<groupId>org.mockito</groupId>
94+
<artifactId>mockito-core</artifactId>
95+
<version>5.13.0</version>
96+
<scope>test</scope>
97+
</dependency>
9298
<dependency>
9399
<groupId>org.junit.jupiter</groupId>
94100
<artifactId>junit-jupiter</artifactId>
801 KB
Loading
3.17 KB
Loading
745 KB
Loading
11.2 KB
Loading

aiplatform/resources/woman.png

1.14 MB
Loading
23.9 KB
Loading
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
/*
2+
* Copyright 2024 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package aiplatform;
18+
19+
// [START generativeaionvertexai_gemma2_predict_gpu]
20+
21+
import com.google.cloud.aiplatform.v1.EndpointName;
22+
import com.google.cloud.aiplatform.v1.PredictResponse;
23+
import com.google.cloud.aiplatform.v1.PredictionServiceClient;
24+
import com.google.cloud.aiplatform.v1.PredictionServiceSettings;
25+
import com.google.gson.Gson;
26+
import com.google.protobuf.InvalidProtocolBufferException;
27+
import com.google.protobuf.Value;
28+
import com.google.protobuf.util.JsonFormat;
29+
import java.io.IOException;
30+
import java.util.ArrayList;
31+
import java.util.HashMap;
32+
import java.util.List;
33+
import java.util.Map;
34+
35+
public class Gemma2PredictGpu {
36+
37+
private final PredictionServiceClient predictionServiceClient;
38+
39+
// Constructor to inject the PredictionServiceClient
40+
public Gemma2PredictGpu(PredictionServiceClient predictionServiceClient) {
41+
this.predictionServiceClient = predictionServiceClient;
42+
}
43+
44+
public static void main(String[] args) throws IOException {
45+
// TODO(developer): Replace these variables before running the sample.
46+
String projectId = "YOUR_PROJECT_ID";
47+
String endpointRegion = "us-east4";
48+
String endpointId = "YOUR_ENDPOINT_ID";
49+
50+
PredictionServiceSettings predictionServiceSettings =
51+
PredictionServiceSettings.newBuilder()
52+
.setEndpoint(String.format("%s-aiplatform.googleapis.com:443", endpointRegion))
53+
.build();
54+
PredictionServiceClient predictionServiceClient =
55+
PredictionServiceClient.create(predictionServiceSettings);
56+
Gemma2PredictGpu creator = new Gemma2PredictGpu(predictionServiceClient);
57+
58+
creator.gemma2PredictGpu(projectId, endpointRegion, endpointId);
59+
}
60+
61+
// Demonstrates how to run inference on a Gemma2 model
62+
// deployed to a Vertex AI endpoint with GPU accelerators.
63+
public String gemma2PredictGpu(String projectId, String region,
64+
String endpointId) throws IOException {
65+
Map<String, Object> paramsMap = new HashMap<>();
66+
paramsMap.put("temperature", 0.9);
67+
paramsMap.put("maxOutputTokens", 1024);
68+
paramsMap.put("topP", 1.0);
69+
paramsMap.put("topK", 1);
70+
Value parameters = mapToValue(paramsMap);
71+
72+
// Prompt used in the prediction
73+
String instance = "{ \"inputs\": \"Why is the sky blue?\"}";
74+
Value.Builder instanceValue = Value.newBuilder();
75+
JsonFormat.parser().merge(instance, instanceValue);
76+
// Encapsulate the prompt in a correct format for GPUs
77+
// Example format: [{'inputs': 'Why is the sky blue?', 'parameters': {'temperature': 0.8}}]
78+
List<Value> instances = new ArrayList<>();
79+
instances.add(instanceValue.build());
80+
81+
EndpointName endpointName = EndpointName.of(projectId, region, endpointId);
82+
83+
PredictResponse predictResponse = this.predictionServiceClient
84+
.predict(endpointName, instances, parameters);
85+
String textResponse = predictResponse.getPredictions(0).getStringValue();
86+
System.out.println(textResponse);
87+
return textResponse;
88+
}
89+
90+
private static Value mapToValue(Map<String, Object> map) throws InvalidProtocolBufferException {
91+
Gson gson = new Gson();
92+
String json = gson.toJson(map);
93+
Value.Builder builder = Value.newBuilder();
94+
JsonFormat.parser().merge(json, builder);
95+
return builder.build();
96+
}
97+
}
98+
// [END generativeaionvertexai_gemma2_predict_gpu]
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
/*
2+
* Copyright 2024 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package aiplatform;
18+
19+
// [START generativeaionvertexai_gemma2_predict_tpu]
20+
21+
import com.google.cloud.aiplatform.v1.EndpointName;
22+
import com.google.cloud.aiplatform.v1.PredictResponse;
23+
import com.google.cloud.aiplatform.v1.PredictionServiceClient;
24+
import com.google.cloud.aiplatform.v1.PredictionServiceSettings;
25+
import com.google.gson.Gson;
26+
import com.google.protobuf.InvalidProtocolBufferException;
27+
import com.google.protobuf.Value;
28+
import com.google.protobuf.util.JsonFormat;
29+
import java.io.IOException;
30+
import java.util.ArrayList;
31+
import java.util.HashMap;
32+
import java.util.List;
33+
import java.util.Map;
34+
35+
public class Gemma2PredictTpu {
36+
private final PredictionServiceClient predictionServiceClient;
37+
38+
// Constructor to inject the PredictionServiceClient
39+
public Gemma2PredictTpu(PredictionServiceClient predictionServiceClient) {
40+
this.predictionServiceClient = predictionServiceClient;
41+
}
42+
43+
public static void main(String[] args) throws IOException {
44+
// TODO(developer): Replace these variables before running the sample.
45+
String projectId = "YOUR_PROJECT_ID";
46+
String endpointRegion = "us-west1";
47+
String endpointId = "YOUR_ENDPOINT_ID";
48+
49+
PredictionServiceSettings predictionServiceSettings =
50+
PredictionServiceSettings.newBuilder()
51+
.setEndpoint(String.format("%s-aiplatform.googleapis.com:443", endpointRegion))
52+
.build();
53+
PredictionServiceClient predictionServiceClient =
54+
PredictionServiceClient.create(predictionServiceSettings);
55+
Gemma2PredictTpu creator = new Gemma2PredictTpu(predictionServiceClient);
56+
57+
creator.gemma2PredictTpu(projectId, endpointRegion, endpointId);
58+
}
59+
60+
// Demonstrates how to run inference on a Gemma2 model
61+
// deployed to a Vertex AI endpoint with TPU accelerators.
62+
public String gemma2PredictTpu(String projectId, String region,
63+
String endpointId) throws IOException {
64+
Map<String, Object> paramsMap = new HashMap<>();
65+
paramsMap.put("temperature", 0.9);
66+
paramsMap.put("maxOutputTokens", 1024);
67+
paramsMap.put("topP", 1.0);
68+
paramsMap.put("topK", 1);
69+
Value parameters = mapToValue(paramsMap);
70+
// Prompt used in the prediction
71+
String instance = "{ \"prompt\": \"Why is the sky blue?\"}";
72+
Value.Builder instanceValue = Value.newBuilder();
73+
JsonFormat.parser().merge(instance, instanceValue);
74+
// Encapsulate the prompt in a correct format for TPUs
75+
// Example format: [{'prompt': 'Why is the sky blue?', 'temperature': 0.9}]
76+
List<Value> instances = new ArrayList<>();
77+
instances.add(instanceValue.build());
78+
79+
EndpointName endpointName = EndpointName.of(projectId, region, endpointId);
80+
81+
PredictResponse predictResponse = this.predictionServiceClient
82+
.predict(endpointName, instances, parameters);
83+
String textResponse = predictResponse.getPredictions(0).getStringValue();
84+
System.out.println(textResponse);
85+
return textResponse;
86+
}
87+
88+
private static Value mapToValue(Map<String, Object> map) throws InvalidProtocolBufferException {
89+
Gson gson = new Gson();
90+
String json = gson.toJson(map);
91+
Value.Builder builder = Value.newBuilder();
92+
JsonFormat.parser().merge(json, builder);
93+
return builder.build();
94+
}
95+
}
96+
// [END generativeaionvertexai_gemma2_predict_tpu]
97+
Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
/*
2+
* Copyright 2024 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package aiplatform.imagen;
18+
19+
// [START generativeaionvertexai_imagen_edit_image_inpainting_insert_mask]
20+
21+
import com.google.api.gax.rpc.ApiException;
22+
import com.google.cloud.aiplatform.v1.EndpointName;
23+
import com.google.cloud.aiplatform.v1.PredictResponse;
24+
import com.google.cloud.aiplatform.v1.PredictionServiceClient;
25+
import com.google.cloud.aiplatform.v1.PredictionServiceSettings;
26+
import com.google.gson.Gson;
27+
import com.google.protobuf.InvalidProtocolBufferException;
28+
import com.google.protobuf.Value;
29+
import com.google.protobuf.util.JsonFormat;
30+
import java.io.IOException;
31+
import java.nio.file.Files;
32+
import java.nio.file.Path;
33+
import java.nio.file.Paths;
34+
import java.util.Base64;
35+
import java.util.Collections;
36+
import java.util.HashMap;
37+
import java.util.Map;
38+
39+
public class EditImageInpaintingInsertMaskSample {
40+
41+
public static void main(String[] args) throws IOException {
42+
// TODO(developer): Replace these variables before running the sample.
43+
String projectId = "my-project-id";
44+
String location = "us-central1";
45+
String inputPath = "/path/to/my-input.png";
46+
String maskPath = "/path/to/my-mask.png";
47+
String prompt =
48+
""; // The text prompt describing what you want to see inserted in the mask area.
49+
50+
editImageInpaintingInsertMask(projectId, location, inputPath, maskPath, prompt);
51+
}
52+
53+
// Edit an image using a mask file. Inpainting can insert the object designated by the prompt
54+
// into the masked area.
55+
public static PredictResponse editImageInpaintingInsertMask(
56+
String projectId, String location, String inputPath, String maskPath, String prompt)
57+
throws ApiException, IOException {
58+
final String endpoint = String.format("%s-aiplatform.googleapis.com:443", location);
59+
PredictionServiceSettings predictionServiceSettings =
60+
PredictionServiceSettings.newBuilder().setEndpoint(endpoint).build();
61+
62+
// Initialize client that will be used to send requests. This client only needs to be created
63+
// once, and can be reused for multiple requests.
64+
try (PredictionServiceClient predictionServiceClient =
65+
PredictionServiceClient.create(predictionServiceSettings)) {
66+
67+
final EndpointName endpointName =
68+
EndpointName.ofProjectLocationPublisherModelName(
69+
projectId, location, "google", "imagegeneration@006");
70+
71+
// Encode image and mask to Base64
72+
String imageBase64 =
73+
Base64.getEncoder().encodeToString(Files.readAllBytes(Paths.get(inputPath)));
74+
String maskBase64 =
75+
Base64.getEncoder().encodeToString(Files.readAllBytes(Paths.get(maskPath)));
76+
77+
// Create the image and image mask maps
78+
Map<String, String> imageMap = new HashMap<>();
79+
imageMap.put("bytesBase64Encoded", imageBase64);
80+
81+
Map<String, String> maskMap = new HashMap<>();
82+
maskMap.put("bytesBase64Encoded", maskBase64);
83+
Map<String, Map> imageMaskMap = new HashMap<>();
84+
imageMaskMap.put("image", maskMap);
85+
86+
Map<String, Object> instancesMap = new HashMap<>();
87+
instancesMap.put("prompt", prompt); // [ "prompt", "<my-prompt>" ]
88+
instancesMap.put(
89+
"image", imageMap); // [ "image", [ "bytesBase64Encoded", "iVBORw0KGgo...==" ] ]
90+
instancesMap.put(
91+
"mask",
92+
imageMaskMap); // [ "mask", [ "image", [ "bytesBase64Encoded", "iJKDF0KGpl...==" ] ] ]
93+
instancesMap.put("editMode", "inpainting-insert"); // [ "editMode", "inpainting-insert" ]
94+
Value instances = mapToValue(instancesMap);
95+
96+
// Optional parameters
97+
Map<String, Object> paramsMap = new HashMap<>();
98+
paramsMap.put("sampleCount", 1);
99+
Value parameters = mapToValue(paramsMap);
100+
101+
PredictResponse predictResponse =
102+
predictionServiceClient.predict(
103+
endpointName, Collections.singletonList(instances), parameters);
104+
105+
for (Value prediction : predictResponse.getPredictionsList()) {
106+
Map<String, Value> fieldsMap = prediction.getStructValue().getFieldsMap();
107+
if (fieldsMap.containsKey("bytesBase64Encoded")) {
108+
String bytesBase64Encoded = fieldsMap.get("bytesBase64Encoded").getStringValue();
109+
Path tmpPath = Files.createTempFile("imagen-", ".png");
110+
Files.write(tmpPath, Base64.getDecoder().decode(bytesBase64Encoded));
111+
System.out.format("Image file written to: %s\n", tmpPath.toUri());
112+
}
113+
}
114+
return predictResponse;
115+
}
116+
}
117+
118+
private static Value mapToValue(Map<String, Object> map) throws InvalidProtocolBufferException {
119+
Gson gson = new Gson();
120+
String json = gson.toJson(map);
121+
Value.Builder builder = Value.newBuilder();
122+
JsonFormat.parser().merge(json, builder);
123+
return builder.build();
124+
}
125+
}
126+
127+
// [END generativeaionvertexai_imagen_edit_image_inpainting_insert_mask]

0 commit comments

Comments
 (0)