Skip to content

Commit 07b6953

Browse files
authored
feat: add Imagen inpainting and outpainting samples and tests (#9524)
* feat: add Imagen inpainting and outpainting samples and tests * fix lint * address feedback * fix commented out files * address feedback
1 parent 39ee185 commit 07b6953

12 files changed

+427
-27
lines changed
801 KB
Loading
3.17 KB
Loading
745 KB
Loading
11.2 KB
Loading

aiplatform/resources/woman.png

1.14 MB
Loading
23.9 KB
Loading
Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
/*
2+
* Copyright 2024 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package aiplatform.imagen;
18+
19+
// [START generativeaionvertexai_imagen_edit_image_inpainting_insert_mask]
20+
21+
import com.google.api.gax.rpc.ApiException;
22+
import com.google.cloud.aiplatform.v1.EndpointName;
23+
import com.google.cloud.aiplatform.v1.PredictResponse;
24+
import com.google.cloud.aiplatform.v1.PredictionServiceClient;
25+
import com.google.cloud.aiplatform.v1.PredictionServiceSettings;
26+
import com.google.gson.Gson;
27+
import com.google.protobuf.InvalidProtocolBufferException;
28+
import com.google.protobuf.Value;
29+
import com.google.protobuf.util.JsonFormat;
30+
import java.io.IOException;
31+
import java.nio.file.Files;
32+
import java.nio.file.Path;
33+
import java.nio.file.Paths;
34+
import java.util.Base64;
35+
import java.util.Collections;
36+
import java.util.HashMap;
37+
import java.util.Map;
38+
39+
public class EditImageInpaintingInsertMaskSample {
40+
41+
public static void main(String[] args) throws IOException {
42+
// TODO(developer): Replace these variables before running the sample.
43+
String projectId = "my-project-id";
44+
String location = "us-central1";
45+
String inputPath = "/path/to/my-input.png";
46+
String maskPath = "/path/to/my-mask.png";
47+
String prompt =
48+
""; // The text prompt describing what you want to see inserted in the mask area.
49+
50+
editImageInpaintingInsertMask(projectId, location, inputPath, maskPath, prompt);
51+
}
52+
53+
// Edit an image using a mask file. Inpainting can insert the object designated by the prompt
54+
// into the masked area.
55+
public static PredictResponse editImageInpaintingInsertMask(
56+
String projectId, String location, String inputPath, String maskPath, String prompt)
57+
throws ApiException, IOException {
58+
final String endpoint = String.format("%s-aiplatform.googleapis.com:443", location);
59+
PredictionServiceSettings predictionServiceSettings =
60+
PredictionServiceSettings.newBuilder().setEndpoint(endpoint).build();
61+
62+
// Initialize client that will be used to send requests. This client only needs to be created
63+
// once, and can be reused for multiple requests.
64+
try (PredictionServiceClient predictionServiceClient =
65+
PredictionServiceClient.create(predictionServiceSettings)) {
66+
67+
final EndpointName endpointName =
68+
EndpointName.ofProjectLocationPublisherModelName(
69+
projectId, location, "google", "imagegeneration@006");
70+
71+
// Encode image and mask to Base64
72+
String imageBase64 =
73+
Base64.getEncoder().encodeToString(Files.readAllBytes(Paths.get(inputPath)));
74+
String maskBase64 =
75+
Base64.getEncoder().encodeToString(Files.readAllBytes(Paths.get(maskPath)));
76+
77+
// Create the image and image mask maps
78+
Map<String, String> imageMap = new HashMap<>();
79+
imageMap.put("bytesBase64Encoded", imageBase64);
80+
81+
Map<String, String> maskMap = new HashMap<>();
82+
maskMap.put("bytesBase64Encoded", maskBase64);
83+
Map<String, Map> imageMaskMap = new HashMap<>();
84+
imageMaskMap.put("image", maskMap);
85+
86+
Map<String, Object> instancesMap = new HashMap<>();
87+
instancesMap.put("prompt", prompt); // [ "prompt", "<my-prompt>" ]
88+
instancesMap.put(
89+
"image", imageMap); // [ "image", [ "bytesBase64Encoded", "iVBORw0KGgo...==" ] ]
90+
instancesMap.put(
91+
"mask",
92+
imageMaskMap); // [ "mask", [ "image", [ "bytesBase64Encoded", "iJKDF0KGpl...==" ] ] ]
93+
instancesMap.put("editMode", "inpainting-insert"); // [ "editMode", "inpainting-insert" ]
94+
Value instances = mapToValue(instancesMap);
95+
96+
// Optional parameters
97+
Map<String, Object> paramsMap = new HashMap<>();
98+
paramsMap.put("sampleCount", 1);
99+
Value parameters = mapToValue(paramsMap);
100+
101+
PredictResponse predictResponse =
102+
predictionServiceClient.predict(
103+
endpointName, Collections.singletonList(instances), parameters);
104+
105+
for (Value prediction : predictResponse.getPredictionsList()) {
106+
Map<String, Value> fieldsMap = prediction.getStructValue().getFieldsMap();
107+
if (fieldsMap.containsKey("bytesBase64Encoded")) {
108+
String bytesBase64Encoded = fieldsMap.get("bytesBase64Encoded").getStringValue();
109+
Path tmpPath = Files.createTempFile("imagen-", ".png");
110+
Files.write(tmpPath, Base64.getDecoder().decode(bytesBase64Encoded));
111+
System.out.format("Image file written to: %s\n", tmpPath.toUri());
112+
}
113+
}
114+
return predictResponse;
115+
}
116+
}
117+
118+
private static Value mapToValue(Map<String, Object> map) throws InvalidProtocolBufferException {
119+
Gson gson = new Gson();
120+
String json = gson.toJson(map);
121+
Value.Builder builder = Value.newBuilder();
122+
JsonFormat.parser().merge(json, builder);
123+
return builder.build();
124+
}
125+
}
126+
127+
// [END generativeaionvertexai_imagen_edit_image_inpainting_insert_mask]

aiplatform/src/main/java/aiplatform/imagen/EditImageMaskFreeSample.java renamed to aiplatform/src/main/java/aiplatform/imagen/EditImageInpaintingRemoveMaskSample.java

Lines changed: 31 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
package aiplatform.imagen;
1818

19-
// [START generativeaionvertexai_imagen_edit_image_mask_free]
19+
// [START generativeaionvertexai_imagen_edit_image_inpainting_remove_mask]
2020

2121
import com.google.api.gax.rpc.ApiException;
2222
import com.google.cloud.aiplatform.v1.EndpointName;
@@ -28,7 +28,6 @@
2828
import com.google.protobuf.Value;
2929
import com.google.protobuf.util.JsonFormat;
3030
import java.io.IOException;
31-
import java.nio.charset.StandardCharsets;
3231
import java.nio.file.Files;
3332
import java.nio.file.Path;
3433
import java.nio.file.Paths;
@@ -37,22 +36,22 @@
3736
import java.util.HashMap;
3837
import java.util.Map;
3938

40-
public class EditImageMaskFreeSample {
39+
public class EditImageInpaintingRemoveMaskSample {
4140

4241
public static void main(String[] args) throws IOException {
4342
// TODO(developer): Replace these variables before running the sample.
4443
String projectId = "my-project-id";
4544
String location = "us-central1";
4645
String inputPath = "/path/to/my-input.png";
47-
String prompt = ""; // The text prompt describing what you want to see.
46+
String maskPath = "/path/to/my-mask.png";
47+
String prompt = ""; // The text prompt describing the entire image.
4848

49-
editImageMaskFree(projectId, location, inputPath, prompt);
49+
editImageInpaintingRemoveMask(projectId, location, inputPath, maskPath, prompt);
5050
}
5151

52-
// Edit an image without using a mask. The edit is applied to the entire image and is saved to a
53-
// new file.
54-
public static PredictResponse editImageMaskFree(
55-
String projectId, String location, String inputPath, String prompt)
52+
// Edit an image using a mask file. Inpainting can remove an object from the masked area.
53+
public static PredictResponse editImageInpaintingRemoveMask(
54+
String projectId, String location, String inputPath, String maskPath, String prompt)
5655
throws ApiException, IOException {
5756
final String endpoint = String.format("%s-aiplatform.googleapis.com:443", location);
5857
PredictionServiceSettings predictionServiceSettings =
@@ -65,25 +64,35 @@ public static PredictResponse editImageMaskFree(
6564

6665
final EndpointName endpointName =
6766
EndpointName.ofProjectLocationPublisherModelName(
68-
projectId, location, "google", "imagegeneration@002");
67+
projectId, location, "google", "imagegeneration@006");
6968

70-
// Convert the image to Base64.
71-
byte[] imageData = Base64.getEncoder().encode(Files.readAllBytes(Paths.get(inputPath)));
72-
String image = new String(imageData, StandardCharsets.UTF_8);
69+
// Encode image and mask to Base64
70+
String imageBase64 =
71+
Base64.getEncoder().encodeToString(Files.readAllBytes(Paths.get(inputPath)));
72+
String maskBase64 =
73+
Base64.getEncoder().encodeToString(Files.readAllBytes(Paths.get(maskPath)));
74+
75+
// Create the image and image mask maps
7376
Map<String, String> imageMap = new HashMap<>();
74-
imageMap.put("bytesBase64Encoded", image);
77+
imageMap.put("bytesBase64Encoded", imageBase64);
78+
79+
Map<String, String> maskMap = new HashMap<>();
80+
maskMap.put("bytesBase64Encoded", maskBase64);
81+
Map<String, Map> imageMaskMap = new HashMap<>();
82+
imageMaskMap.put("image", maskMap);
7583

7684
Map<String, Object> instancesMap = new HashMap<>();
77-
instancesMap.put("prompt", prompt);
78-
instancesMap.put("image", imageMap);
85+
instancesMap.put("prompt", prompt); // [ "prompt", "<my-prompt>" ]
86+
instancesMap.put(
87+
"image", imageMap); // [ "image", [ "bytesBase64Encoded", "iVBORw0KGgo...==" ] ]
88+
instancesMap.put(
89+
"mask",
90+
imageMaskMap); // [ "mask", [ "image", [ "bytesBase64Encoded", "iJKDF0KGpl...==" ] ] ]
91+
instancesMap.put("editMode", "inpainting-remove"); // [ "editMode", "inpainting-remove" ]
7992
Value instances = mapToValue(instancesMap);
8093

81-
Map<String, Object> paramsMap = new HashMap<>();
8294
// Optional parameters
83-
paramsMap.put("seed", 1);
84-
// Controls the strength of the prompt.
85-
// 0-9 (low strength), 10-20 (medium strength), 21+ (high strength)
86-
paramsMap.put("guidanceScale", 21);
95+
Map<String, Object> paramsMap = new HashMap<>();
8796
paramsMap.put("sampleCount", 1);
8897
Value parameters = mapToValue(paramsMap);
8998

@@ -113,4 +122,4 @@ private static Value mapToValue(Map<String, Object> map) throws InvalidProtocolB
113122
}
114123
}
115124

116-
// [END generativeaionvertexai_imagen_edit_image_mask_free]
125+
// [END generativeaionvertexai_imagen_edit_image_inpainting_remove_mask]
Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
/*
2+
* Copyright 2024 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package aiplatform.imagen;
18+
19+
// [START generativeaionvertexai_imagen_edit_image_outpainting_mask]
20+
21+
import com.google.api.gax.rpc.ApiException;
22+
import com.google.cloud.aiplatform.v1.EndpointName;
23+
import com.google.cloud.aiplatform.v1.PredictResponse;
24+
import com.google.cloud.aiplatform.v1.PredictionServiceClient;
25+
import com.google.cloud.aiplatform.v1.PredictionServiceSettings;
26+
import com.google.gson.Gson;
27+
import com.google.protobuf.InvalidProtocolBufferException;
28+
import com.google.protobuf.Value;
29+
import com.google.protobuf.util.JsonFormat;
30+
import java.io.IOException;
31+
import java.nio.file.Files;
32+
import java.nio.file.Path;
33+
import java.nio.file.Paths;
34+
import java.util.Base64;
35+
import java.util.Collections;
36+
import java.util.HashMap;
37+
import java.util.Map;
38+
39+
public class EditImageOutpaintingMaskSample {
40+
41+
public static void main(String[] args) throws IOException {
42+
// TODO(developer): Replace these variables before running the sample.
43+
String projectId = "my-project-id";
44+
String location = "us-central1";
45+
String inputPath = "/path/to/my-input.png";
46+
String maskPath = "/path/to/my-mask.png";
47+
String prompt = ""; // The optional text prompt describing what you want to see inserted.
48+
49+
editImageOutpaintingMask(projectId, location, inputPath, maskPath, prompt);
50+
}
51+
52+
// Edit an image using a mask file. Outpainting lets you expand the content of a base image to fit
53+
// a larger or differently sized mask canvas.
54+
public static PredictResponse editImageOutpaintingMask(
55+
String projectId, String location, String inputPath, String maskPath, String prompt)
56+
throws ApiException, IOException {
57+
final String endpoint = String.format("%s-aiplatform.googleapis.com:443", location);
58+
PredictionServiceSettings predictionServiceSettings =
59+
PredictionServiceSettings.newBuilder().setEndpoint(endpoint).build();
60+
61+
// Initialize client that will be used to send requests. This client only needs to be created
62+
// once, and can be reused for multiple requests.
63+
try (PredictionServiceClient predictionServiceClient =
64+
PredictionServiceClient.create(predictionServiceSettings)) {
65+
66+
final EndpointName endpointName =
67+
EndpointName.ofProjectLocationPublisherModelName(
68+
projectId, location, "google", "imagegeneration@006");
69+
70+
// Encode image and mask to Base64
71+
String imageBase64 =
72+
Base64.getEncoder().encodeToString(Files.readAllBytes(Paths.get(inputPath)));
73+
String maskBase64 =
74+
Base64.getEncoder().encodeToString(Files.readAllBytes(Paths.get(maskPath)));
75+
76+
// Create the image and image mask maps
77+
Map<String, String> imageMap = new HashMap<>();
78+
imageMap.put("bytesBase64Encoded", imageBase64);
79+
80+
Map<String, String> maskMap = new HashMap<>();
81+
maskMap.put("bytesBase64Encoded", maskBase64);
82+
Map<String, Map> imageMaskMap = new HashMap<>();
83+
imageMaskMap.put("image", maskMap);
84+
85+
Map<String, Object> instancesMap = new HashMap<>();
86+
instancesMap.put("prompt", prompt); // [ "prompt", "<my-prompt>" ]
87+
instancesMap.put(
88+
"image", imageMap); // [ "image", [ "bytesBase64Encoded", "iVBORw0KGgo...==" ] ]
89+
instancesMap.put(
90+
"mask",
91+
imageMaskMap); // [ "mask", [ "image", [ "bytesBase64Encoded", "iJKDF0KGpl...==" ] ] ]
92+
instancesMap.put("editMode", "outpainting"); // [ "editMode", "outpainting" ]
93+
Value instances = mapToValue(instancesMap);
94+
95+
// Optional parameters
96+
Map<String, Object> paramsMap = new HashMap<>();
97+
paramsMap.put("sampleCount", 1);
98+
Value parameters = mapToValue(paramsMap);
99+
100+
PredictResponse predictResponse =
101+
predictionServiceClient.predict(
102+
endpointName, Collections.singletonList(instances), parameters);
103+
104+
for (Value prediction : predictResponse.getPredictionsList()) {
105+
Map<String, Value> fieldsMap = prediction.getStructValue().getFieldsMap();
106+
if (fieldsMap.containsKey("bytesBase64Encoded")) {
107+
String bytesBase64Encoded = fieldsMap.get("bytesBase64Encoded").getStringValue();
108+
Path tmpPath = Files.createTempFile("imagen-", ".png");
109+
Files.write(tmpPath, Base64.getDecoder().decode(bytesBase64Encoded));
110+
System.out.format("Image file written to: %s\n", tmpPath.toUri());
111+
}
112+
}
113+
return predictResponse;
114+
}
115+
}
116+
117+
private static Value mapToValue(Map<String, Object> map) throws InvalidProtocolBufferException {
118+
Gson gson = new Gson();
119+
String json = gson.toJson(map);
120+
Value.Builder builder = Value.newBuilder();
121+
JsonFormat.parser().merge(json, builder);
122+
return builder.build();
123+
}
124+
}
125+
126+
// [END generativeaionvertexai_imagen_edit_image_outpainting_mask]

0 commit comments

Comments
 (0)