Skip to content

Commit cd0b601

Browse files
irataxySita04
andauthored
feat: add Imagen watermarking and captions samples and tests (#9549)
* feat: add Imagen watermarking and captions samples and tests * add comment * remove parameter * Update aiplatform/src/main/java/aiplatform/imagen/VerifyImageWatermarkSample.java Co-authored-by: Sita Lakshmi Sangameswaran <[email protected]> * address feedback --------- Co-authored-by: Sita Lakshmi Sangameswaran <[email protected]>
1 parent a002e61 commit cd0b601

File tree

7 files changed

+509
-0
lines changed

7 files changed

+509
-0
lines changed
1.02 MB
Loading
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
/*
2+
* Copyright 2024 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package aiplatform.imagen;
18+
19+
// [START generativeaionvertexai_imagen_get_short_form_image_captions]
20+
21+
import com.google.api.gax.rpc.ApiException;
22+
import com.google.cloud.aiplatform.v1.EndpointName;
23+
import com.google.cloud.aiplatform.v1.PredictResponse;
24+
import com.google.cloud.aiplatform.v1.PredictionServiceClient;
25+
import com.google.cloud.aiplatform.v1.PredictionServiceSettings;
26+
import com.google.gson.Gson;
27+
import com.google.protobuf.InvalidProtocolBufferException;
28+
import com.google.protobuf.Value;
29+
import com.google.protobuf.util.JsonFormat;
30+
import java.io.IOException;
31+
import java.nio.file.Files;
32+
import java.nio.file.Paths;
33+
import java.util.Base64;
34+
import java.util.Collections;
35+
import java.util.HashMap;
36+
import java.util.Map;
37+
38+
public class GetShortFormImageCaptionsSample {
39+
40+
public static void main(String[] args) throws IOException {
41+
// TODO(developer): Replace these variables before running the sample.
42+
String projectId = "my-project-id";
43+
String location = "us-central1";
44+
String inputPath = "/path/to/my-input.png";
45+
46+
getShortFormImageCaptions(projectId, location, inputPath);
47+
}
48+
49+
// Get the short form captions for an image
50+
public static PredictResponse getShortFormImageCaptions(
51+
String projectId, String location, String inputPath) throws ApiException, IOException {
52+
final String endpoint = String.format("%s-aiplatform.googleapis.com:443", location);
53+
PredictionServiceSettings predictionServiceSettings =
54+
PredictionServiceSettings.newBuilder().setEndpoint(endpoint).build();
55+
56+
// Initialize client that will be used to send requests. This client only needs to be created
57+
// once, and can be reused for multiple requests.
58+
try (PredictionServiceClient predictionServiceClient =
59+
PredictionServiceClient.create(predictionServiceSettings)) {
60+
61+
final EndpointName endpointName =
62+
EndpointName.ofProjectLocationPublisherModelName(
63+
projectId, location, "google", "imagetext@001");
64+
65+
// Encode image to Base64
66+
String imageBase64 =
67+
Base64.getEncoder().encodeToString(Files.readAllBytes(Paths.get(inputPath)));
68+
69+
// Create the image map
70+
Map<String, String> imageMap = new HashMap<>();
71+
imageMap.put("bytesBase64Encoded", imageBase64);
72+
73+
Map<String, Object> instancesMap = new HashMap<>();
74+
instancesMap.put("image", imageMap);
75+
Value instances = mapToValue(instancesMap);
76+
77+
// Optional parameters
78+
Map<String, Object> paramsMap = new HashMap<>();
79+
paramsMap.put("language", "en");
80+
paramsMap.put("sampleCount", 2);
81+
Value parameters = mapToValue(paramsMap);
82+
83+
PredictResponse predictResponse =
84+
predictionServiceClient.predict(
85+
endpointName, Collections.singletonList(instances), parameters);
86+
87+
for (Value prediction : predictResponse.getPredictionsList()) {
88+
System.out.println(prediction.getStringValue());
89+
}
90+
return predictResponse;
91+
}
92+
}
93+
94+
private static Value mapToValue(Map<String, Object> map) throws InvalidProtocolBufferException {
95+
Gson gson = new Gson();
96+
String json = gson.toJson(map);
97+
Value.Builder builder = Value.newBuilder();
98+
JsonFormat.parser().merge(json, builder);
99+
return builder.build();
100+
}
101+
}
102+
103+
// [END generativeaionvertexai_imagen_get_short_form_image_captions]
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
/*
2+
* Copyright 2024 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package aiplatform.imagen;
18+
19+
// [START generativeaionvertexai_imagen_get_short_form_image_responses]
20+
21+
import com.google.api.gax.rpc.ApiException;
22+
import com.google.cloud.aiplatform.v1.EndpointName;
23+
import com.google.cloud.aiplatform.v1.PredictResponse;
24+
import com.google.cloud.aiplatform.v1.PredictionServiceClient;
25+
import com.google.cloud.aiplatform.v1.PredictionServiceSettings;
26+
import com.google.gson.Gson;
27+
import com.google.protobuf.InvalidProtocolBufferException;
28+
import com.google.protobuf.Value;
29+
import com.google.protobuf.util.JsonFormat;
30+
import java.io.IOException;
31+
import java.nio.file.Files;
32+
import java.nio.file.Paths;
33+
import java.util.Base64;
34+
import java.util.Collections;
35+
import java.util.HashMap;
36+
import java.util.Map;
37+
38+
public class GetShortFormImageResponsesSample {
39+
40+
public static void main(String[] args) throws IOException {
41+
// TODO(developer): Replace these variables before running the sample.
42+
String projectId = "my-project-id";
43+
String location = "us-central1";
44+
String inputPath = "/path/to/my-input.png";
45+
String prompt = ""; // The question about the contents of the image.
46+
47+
getShortFormImageResponses(projectId, location, inputPath, prompt);
48+
}
49+
50+
// Get the short form responses to a question about an image
51+
public static PredictResponse getShortFormImageResponses(
52+
String projectId, String location, String inputPath, String prompt)
53+
throws ApiException, IOException {
54+
final String endpoint = String.format("%s-aiplatform.googleapis.com:443", location);
55+
PredictionServiceSettings predictionServiceSettings =
56+
PredictionServiceSettings.newBuilder().setEndpoint(endpoint).build();
57+
58+
// Initialize client that will be used to send requests. This client only needs to be created
59+
// once, and can be reused for multiple requests.
60+
try (PredictionServiceClient predictionServiceClient =
61+
PredictionServiceClient.create(predictionServiceSettings)) {
62+
63+
final EndpointName endpointName =
64+
EndpointName.ofProjectLocationPublisherModelName(
65+
projectId, location, "google", "imagetext@001");
66+
67+
// Encode image to Base64
68+
String imageBase64 =
69+
Base64.getEncoder().encodeToString(Files.readAllBytes(Paths.get(inputPath)));
70+
71+
// Create the image map
72+
Map<String, String> imageMap = new HashMap<>();
73+
imageMap.put("bytesBase64Encoded", imageBase64);
74+
75+
Map<String, Object> instancesMap = new HashMap<>();
76+
instancesMap.put("prompt", prompt);
77+
instancesMap.put("image", imageMap);
78+
Value instances = mapToValue(instancesMap);
79+
80+
// Optional parameters
81+
Map<String, Object> paramsMap = new HashMap<>();
82+
paramsMap.put("sampleCount", 2);
83+
Value parameters = mapToValue(paramsMap);
84+
85+
PredictResponse predictResponse =
86+
predictionServiceClient.predict(
87+
endpointName, Collections.singletonList(instances), parameters);
88+
89+
for (Value prediction : predictResponse.getPredictionsList()) {
90+
System.out.println(prediction.getStringValue());
91+
}
92+
return predictResponse;
93+
}
94+
}
95+
96+
private static Value mapToValue(Map<String, Object> map) throws InvalidProtocolBufferException {
97+
Gson gson = new Gson();
98+
String json = gson.toJson(map);
99+
Value.Builder builder = Value.newBuilder();
100+
JsonFormat.parser().merge(json, builder);
101+
return builder.build();
102+
}
103+
}
104+
105+
// [END generativeaionvertexai_imagen_get_short_form_image_responses]
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
/*
2+
* Copyright 2024 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package aiplatform.imagen;
18+
19+
// [START generativeaionvertexai_imagen_verify_image_watermark]
20+
21+
import com.google.api.gax.rpc.ApiException;
22+
import com.google.cloud.aiplatform.v1.EndpointName;
23+
import com.google.cloud.aiplatform.v1.PredictResponse;
24+
import com.google.cloud.aiplatform.v1.PredictionServiceClient;
25+
import com.google.cloud.aiplatform.v1.PredictionServiceSettings;
26+
import com.google.gson.Gson;
27+
import com.google.protobuf.InvalidProtocolBufferException;
28+
import com.google.protobuf.Value;
29+
import com.google.protobuf.util.JsonFormat;
30+
import java.io.IOException;
31+
import java.nio.file.Files;
32+
import java.nio.file.Paths;
33+
import java.util.Base64;
34+
import java.util.Collections;
35+
import java.util.HashMap;
36+
import java.util.Map;
37+
38+
public class VerifyImageWatermarkSample {
39+
40+
public static void main(String[] args) throws IOException {
41+
// TODO(developer): Replace these variables before running the sample.
42+
String projectId = "my-project-id";
43+
String location = "us-central1";
44+
String inputPath = "/path/to/my-input.png";
45+
46+
verifyImageWatermark(projectId, location, inputPath);
47+
}
48+
49+
// Verify if an image contains a digital watermark. By default, a non-visible, digital watermark
50+
// (called a SynthID) is added to images generated by a model version that supports
51+
// watermark generation.
52+
public static PredictResponse verifyImageWatermark(
53+
String projectId, String location, String inputPath) throws ApiException, IOException {
54+
final String endpoint = String.format("%s-aiplatform.googleapis.com:443", location);
55+
PredictionServiceSettings predictionServiceSettings =
56+
PredictionServiceSettings.newBuilder().setEndpoint(endpoint).build();
57+
58+
// Initialize client that will be used to send requests. This client only needs to be created
59+
// once, and can be reused for multiple requests.
60+
try (PredictionServiceClient predictionServiceClient =
61+
PredictionServiceClient.create(predictionServiceSettings)) {
62+
63+
final EndpointName endpointName =
64+
EndpointName.ofProjectLocationPublisherModelName(
65+
projectId, location, "google", "imageverification@001");
66+
67+
// Encode image to Base64
68+
String imageBase64 =
69+
Base64.getEncoder().encodeToString(Files.readAllBytes(Paths.get(inputPath)));
70+
71+
// Create the image map
72+
Map<String, String> imageMap = new HashMap<>();
73+
imageMap.put("bytesBase64Encoded", imageBase64);
74+
75+
Map<String, Object> instancesMap = new HashMap<>();
76+
instancesMap.put("image", imageMap);
77+
Value instances = mapToValue(instancesMap);
78+
79+
// Optional parameters
80+
Map<String, Object> paramsMap = new HashMap<>();
81+
Value parameters = mapToValue(paramsMap);
82+
83+
PredictResponse predictResponse =
84+
predictionServiceClient.predict(
85+
endpointName, Collections.singletonList(instances), parameters);
86+
87+
for (Value prediction : predictResponse.getPredictionsList()) {
88+
Map<String, Value> fieldsMap = prediction.getStructValue().getFieldsMap();
89+
if (fieldsMap.containsKey("decision")) {
90+
// "ACCEPT" if the image contains a digital watermark
91+
// "REJECT" if the image does not contain a digital watermark
92+
System.out.format(
93+
"Watermark verification result: %s", fieldsMap.get("decision").getStringValue());
94+
}
95+
}
96+
return predictResponse;
97+
}
98+
}
99+
100+
private static Value mapToValue(Map<String, Object> map) throws InvalidProtocolBufferException {
101+
Gson gson = new Gson();
102+
String json = gson.toJson(map);
103+
Value.Builder builder = Value.newBuilder();
104+
JsonFormat.parser().merge(json, builder);
105+
return builder.build();
106+
}
107+
}
108+
109+
// [END generativeaionvertexai_imagen_verify_image_watermark]
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
/*
2+
* Copyright 2024 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package aiplatform.imagen;
18+
19+
import static com.google.common.truth.Truth.assertThat;
20+
import static junit.framework.TestCase.assertNotNull;
21+
22+
import com.google.cloud.aiplatform.v1.PredictResponse;
23+
import com.google.protobuf.Value;
24+
import java.io.IOException;
25+
import org.junit.BeforeClass;
26+
import org.junit.Test;
27+
import org.junit.runner.RunWith;
28+
import org.junit.runners.JUnit4;
29+
30+
@RunWith(JUnit4.class)
31+
public class GetShortFormImageCaptionsSampleTest {
32+
33+
private static final String PROJECT = System.getenv("GOOGLE_CLOUD_PROJECT");
34+
private static final String INPUT_FILE = "resources/cat.png";
35+
36+
private static void requireEnvVar(String varName) {
37+
String errorMessage =
38+
String.format("Environment variable '%s' is required to perform these tests.", varName);
39+
assertNotNull(errorMessage, System.getenv(varName));
40+
}
41+
42+
@BeforeClass
43+
public static void checkRequirements() {
44+
requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS");
45+
requireEnvVar("GOOGLE_CLOUD_PROJECT");
46+
}
47+
48+
@Test
49+
public void testGetShortFormImageCaptionsSample() throws IOException {
50+
PredictResponse response =
51+
GetShortFormImageCaptionsSample.getShortFormImageCaptions(
52+
PROJECT, "us-central1", INPUT_FILE);
53+
assertThat(response).isNotNull();
54+
55+
for (Value prediction : response.getPredictionsList()) {
56+
assertThat(prediction.getStringValue().contains("cat"));
57+
}
58+
}
59+
}

0 commit comments

Comments
 (0)