Skip to content

Commit 09700ae

Browse files
authored
fix: add back removed Imagen sample and test (#9545)
1 parent 07b6953 commit 09700ae

File tree

2 files changed

+182
-0
lines changed

2 files changed

+182
-0
lines changed
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
/*
2+
* Copyright 2024 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package aiplatform.imagen;
18+
19+
// [START generativeaionvertexai_imagen_edit_image_mask_free]
20+
21+
import com.google.api.gax.rpc.ApiException;
22+
import com.google.cloud.aiplatform.v1.EndpointName;
23+
import com.google.cloud.aiplatform.v1.PredictResponse;
24+
import com.google.cloud.aiplatform.v1.PredictionServiceClient;
25+
import com.google.cloud.aiplatform.v1.PredictionServiceSettings;
26+
import com.google.gson.Gson;
27+
import com.google.protobuf.InvalidProtocolBufferException;
28+
import com.google.protobuf.Value;
29+
import com.google.protobuf.util.JsonFormat;
30+
import java.io.IOException;
31+
import java.nio.file.Files;
32+
import java.nio.file.Path;
33+
import java.nio.file.Paths;
34+
import java.util.Base64;
35+
import java.util.Collections;
36+
import java.util.HashMap;
37+
import java.util.Map;
38+
39+
public class EditImageMaskFreeSample {
40+
41+
public static void main(String[] args) throws IOException {
42+
// TODO(developer): Replace these variables before running the sample.
43+
String projectId = "my-project-id";
44+
String location = "us-central1";
45+
String inputPath = "/path/to/my-input.png";
46+
String prompt = ""; // The text prompt describing what you want to see.
47+
48+
editImageMaskFree(projectId, location, inputPath, prompt);
49+
}
50+
51+
// Edit an image without using a mask. The edit is applied to the entire image and is saved to a
52+
// new file.
53+
public static PredictResponse editImageMaskFree(
54+
String projectId, String location, String inputPath, String prompt)
55+
throws ApiException, IOException {
56+
final String endpoint = String.format("%s-aiplatform.googleapis.com:443", location);
57+
PredictionServiceSettings predictionServiceSettings =
58+
PredictionServiceSettings.newBuilder().setEndpoint(endpoint).build();
59+
60+
// Initialize client that will be used to send requests. This client only needs to be created
61+
// once, and can be reused for multiple requests.
62+
try (PredictionServiceClient predictionServiceClient =
63+
PredictionServiceClient.create(predictionServiceSettings)) {
64+
65+
final EndpointName endpointName =
66+
EndpointName.ofProjectLocationPublisherModelName(
67+
projectId, location, "google", "imagegeneration@002");
68+
69+
// Convert the image to Base64 and create the image map
70+
String imageBase64 =
71+
Base64.getEncoder().encodeToString(Files.readAllBytes(Paths.get(inputPath)));
72+
Map<String, String> imageMap = new HashMap<>();
73+
imageMap.put("bytesBase64Encoded", imageBase64);
74+
75+
Map<String, Object> instancesMap = new HashMap<>();
76+
instancesMap.put("prompt", prompt); // [ "prompt", "<my-prompt>" ]
77+
instancesMap.put(
78+
"image", imageMap); // [ "image", [ "bytesBase64Encoded", "iVBORw0KGgo...==" ] ]
79+
Value instances = mapToValue(instancesMap);
80+
81+
Map<String, Object> paramsMap = new HashMap<>();
82+
// Optional parameters
83+
paramsMap.put("seed", 1);
84+
// Controls the strength of the prompt.
85+
// 0-9 (low strength), 10-20 (medium strength), 21+ (high strength)
86+
paramsMap.put("guidanceScale", 21);
87+
paramsMap.put("sampleCount", 1);
88+
Value parameters = mapToValue(paramsMap);
89+
90+
PredictResponse predictResponse =
91+
predictionServiceClient.predict(
92+
endpointName, Collections.singletonList(instances), parameters);
93+
94+
for (Value prediction : predictResponse.getPredictionsList()) {
95+
Map<String, Value> fieldsMap = prediction.getStructValue().getFieldsMap();
96+
if (fieldsMap.containsKey("bytesBase64Encoded")) {
97+
String bytesBase64Encoded = fieldsMap.get("bytesBase64Encoded").getStringValue();
98+
Path tmpPath = Files.createTempFile("imagen-", ".png");
99+
Files.write(tmpPath, Base64.getDecoder().decode(bytesBase64Encoded));
100+
System.out.format("Image file written to: %s\n", tmpPath.toUri());
101+
}
102+
}
103+
return predictResponse;
104+
}
105+
}
106+
107+
private static Value mapToValue(Map<String, Object> map) throws InvalidProtocolBufferException {
108+
Gson gson = new Gson();
109+
String json = gson.toJson(map);
110+
Value.Builder builder = Value.newBuilder();
111+
JsonFormat.parser().merge(json, builder);
112+
return builder.build();
113+
}
114+
}
115+
116+
// [END generativeaionvertexai_imagen_edit_image_mask_free]
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
/*
2+
* Copyright 2024 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package aiplatform.imagen;
18+
19+
import static com.google.common.truth.Truth.assertThat;
20+
import static junit.framework.TestCase.assertNotNull;
21+
22+
import com.google.cloud.aiplatform.v1.PredictResponse;
23+
import com.google.protobuf.Value;
24+
import java.io.IOException;
25+
import java.util.Map;
26+
import org.junit.BeforeClass;
27+
import org.junit.Test;
28+
import org.junit.runner.RunWith;
29+
import org.junit.runners.JUnit4;
30+
31+
@RunWith(JUnit4.class)
32+
public class EditImageMaskFreeSampleTest {
33+
34+
private static final String PROJECT = System.getenv("GOOGLE_CLOUD_PROJECT");
35+
private static final String INPUT_FILE = "resources/cat.png";
36+
private static final String PROMPT = "a dog";
37+
38+
private static void requireEnvVar(String varName) {
39+
String errorMessage =
40+
String.format("Environment variable '%s' is required to perform these tests.", varName);
41+
assertNotNull(errorMessage, System.getenv(varName));
42+
}
43+
44+
@BeforeClass
45+
public static void checkRequirements() {
46+
requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS");
47+
requireEnvVar("GOOGLE_CLOUD_PROJECT");
48+
}
49+
50+
@Test
51+
public void testEditImageMaskFreeSample() throws IOException {
52+
PredictResponse response =
53+
EditImageMaskFreeSample.editImageMaskFree(PROJECT, "us-central1", INPUT_FILE, PROMPT);
54+
assertThat(response).isNotNull();
55+
56+
Boolean imageBytes = false;
57+
for (Value prediction : response.getPredictionsList()) {
58+
Map<String, Value> fieldsMap = prediction.getStructValue().getFieldsMap();
59+
if (fieldsMap.containsKey("bytesBase64Encoded")) {
60+
imageBytes = true;
61+
break;
62+
}
63+
}
64+
assertThat(imageBytes).isTrue();
65+
}
66+
}

0 commit comments

Comments
 (0)