diff --git a/aiplatform/pom.xml b/aiplatform/pom.xml
index 25f989557bf..14e314a1244 100644
--- a/aiplatform/pom.xml
+++ b/aiplatform/pom.xml
@@ -89,6 +89,12 @@
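       <version>1.7.1</version>
       <scope>test</scope>
     </dependency>
+    <dependency>
+      <groupId>org.mockito</groupId>
+      <artifactId>mockito-core</artifactId>
+      <version>5.13.0</version>
+      <scope>test</scope>
+    </dependency>
     <dependency>
       <groupId>org.junit.jupiter</groupId>
       <artifactId>junit-jupiter</artifactId>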
diff --git a/aiplatform/src/main/java/aiplatform/Gemma2PredictGpu.java b/aiplatform/src/main/java/aiplatform/Gemma2PredictGpu.java
new file mode 100644
index 00000000000..2c3b6c7dace
--- /dev/null
+++ b/aiplatform/src/main/java/aiplatform/Gemma2PredictGpu.java
@@ -0,0 +1,98 @@
+/*
+ * Copyright 2024 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+// [START generativeaionvertexai_gemma2_predict_gpu]
+
+import com.google.cloud.aiplatform.v1.EndpointName;
+import com.google.cloud.aiplatform.v1.PredictResponse;
+import com.google.cloud.aiplatform.v1.PredictionServiceClient;
+import com.google.cloud.aiplatform.v1.PredictionServiceSettings;
+import com.google.gson.Gson;
+import com.google.protobuf.InvalidProtocolBufferException;
+import com.google.protobuf.Value;
+import com.google.protobuf.util.JsonFormat;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * Sample that sends a fixed prompt to a Gemma2 model served from a Vertex AI endpoint backed by
+ * GPU accelerators.
+ */
+public class Gemma2PredictGpu {
+
+  private final PredictionServiceClient predictionServiceClient;
+
+  /**
+   * Creates the sample around an injected client.
+   *
+   * @param predictionServiceClient client used for the prediction call; this class never closes
+   *     it, so whoever created it remains responsible for doing so
+   */
+  public Gemma2PredictGpu(PredictionServiceClient predictionServiceClient) {
+    this.predictionServiceClient = predictionServiceClient;
+  }
+
+  public static void main(String[] args) throws IOException {
+    // TODO(developer): Replace these variables before running the sample.
+    String projectId = "YOUR_PROJECT_ID";
+    String endpointRegion = "us-east4";
+    String endpointId = "YOUR_ENDPOINT_ID";
+
+    PredictionServiceSettings predictionServiceSettings =
+        PredictionServiceSettings.newBuilder()
+            .setEndpoint(String.format("%s-aiplatform.googleapis.com:443", endpointRegion))
+            .build();
+
+    // try-with-resources closes the client whether the request succeeds or throws.
+    try (PredictionServiceClient predictionServiceClient =
+        PredictionServiceClient.create(predictionServiceSettings)) {
+      Gemma2PredictGpu creator = new Gemma2PredictGpu(predictionServiceClient);
+
+      creator.gemma2PredictGpu(projectId, endpointRegion, endpointId);
+    }
+  }
+
+  /**
+   * Demonstrates how to run inference on a Gemma2 model deployed to a Vertex AI endpoint with GPU
+   * accelerators.
+   *
+   * @param projectId project component of the endpoint resource name
+   * @param region location component of the endpoint resource name
+   * @param endpointId endpoint component of the endpoint resource name
+   * @return the string value of the first prediction, which is also printed to standard output
+   * @throws IOException if the prompt or the parameters cannot be parsed into protobuf values
+   */
+  public String gemma2PredictGpu(String projectId, String region,
+      String endpointId) throws IOException {
+    // Generation parameters sent alongside the prompt.
+    Map<String, Object> paramsMap = new HashMap<>();
+    paramsMap.put("temperature", 0.9);
+    paramsMap.put("maxOutputTokens", 1024);
+    paramsMap.put("topP", 1.0);
+    paramsMap.put("topK", 1);
+    Value parameters = mapToValue(paramsMap);
+
+    // Prompt used in the prediction
+    String instance = "{ \"inputs\": \"Why is the sky blue?\"}";
+    Value.Builder instanceValue = Value.newBuilder();
+    JsonFormat.parser().merge(instance, instanceValue);
+    // Encapsulate the prompt in a correct format for GPUs
+    // Example format: [{'inputs': 'Why is the sky blue?', 'parameters': {'temperature': 0.8}}]
+    List<Value> instances = new ArrayList<>();
+    instances.add(instanceValue.build());
+
+    EndpointName endpointName = EndpointName.of(projectId, region, endpointId);
+
+    PredictResponse predictResponse = this.predictionServiceClient
+        .predict(endpointName, instances, parameters);
+    String textResponse = predictResponse.getPredictions(0).getStringValue();
+    System.out.println(textResponse);
+    return textResponse;
+  }
+
+  /**
+   * Converts a map into a protobuf {@code Value} by serializing it to JSON with Gson and parsing
+   * that JSON back with {@code JsonFormat}.
+   *
+   * @param map entries to convert
+   * @return the parsed value
+   * @throws InvalidProtocolBufferException if the generated JSON cannot be merged into the builder
+   */
+  private static Value mapToValue(Map<String, Object> map) throws InvalidProtocolBufferException {
+    Gson gson = new Gson();
+    String json = gson.toJson(map);
+    Value.Builder builder = Value.newBuilder();
+    JsonFormat.parser().merge(json, builder);
+    return builder.build();
+  }
+}
+// [END generativeaionvertexai_gemma2_predict_gpu]
\ No newline at end of file
diff --git a/aiplatform/src/main/java/aiplatform/Gemma2PredictTpu.java b/aiplatform/src/main/java/aiplatform/Gemma2PredictTpu.java
new file mode 100644
index 00000000000..de29b1cc111
--- /dev/null
+++ b/aiplatform/src/main/java/aiplatform/Gemma2PredictTpu.java
@@ -0,0 +1,97 @@
+/*
+ * Copyright 2024 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+// [START generativeaionvertexai_gemma2_predict_tpu]
+
+import com.google.cloud.aiplatform.v1.EndpointName;
+import com.google.cloud.aiplatform.v1.PredictResponse;
+import com.google.cloud.aiplatform.v1.PredictionServiceClient;
+import com.google.cloud.aiplatform.v1.PredictionServiceSettings;
+import com.google.gson.Gson;
+import com.google.protobuf.InvalidProtocolBufferException;
+import com.google.protobuf.Value;
+import com.google.protobuf.util.JsonFormat;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * Sample that sends a fixed prompt to a Gemma2 model served from a Vertex AI endpoint backed by
+ * TPU accelerators.
+ */
+public class Gemma2PredictTpu {
+  private final PredictionServiceClient predictionServiceClient;
+
+  /**
+   * Creates the sample around an injected client.
+   *
+   * @param predictionServiceClient client used for the prediction call; this class never closes
+   *     it, so whoever created it remains responsible for doing so
+   */
+  public Gemma2PredictTpu(PredictionServiceClient predictionServiceClient) {
+    this.predictionServiceClient = predictionServiceClient;
+  }
+
+  public static void main(String[] args) throws IOException {
+    // TODO(developer): Replace these variables before running the sample.
+    String projectId = "YOUR_PROJECT_ID";
+    String endpointRegion = "us-west1";
+    String endpointId = "YOUR_ENDPOINT_ID";
+
+    PredictionServiceSettings predictionServiceSettings =
+        PredictionServiceSettings.newBuilder()
+            .setEndpoint(String.format("%s-aiplatform.googleapis.com:443", endpointRegion))
+            .build();
+
+    // try-with-resources closes the client whether the request succeeds or throws.
+    try (PredictionServiceClient predictionServiceClient =
+        PredictionServiceClient.create(predictionServiceSettings)) {
+      Gemma2PredictTpu creator = new Gemma2PredictTpu(predictionServiceClient);
+
+      creator.gemma2PredictTpu(projectId, endpointRegion, endpointId);
+    }
+  }
+
+  /**
+   * Demonstrates how to run inference on a Gemma2 model deployed to a Vertex AI endpoint with TPU
+   * accelerators.
+   *
+   * @param projectId project component of the endpoint resource name
+   * @param region location component of the endpoint resource name
+   * @param endpointId endpoint component of the endpoint resource name
+   * @return the string value of the first prediction, which is also printed to standard output
+   * @throws IOException if the prompt or the parameters cannot be parsed into protobuf values
+   */
+  public String gemma2PredictTpu(String projectId, String region,
+      String endpointId) throws IOException {
+    // Generation parameters sent alongside the prompt.
+    Map<String, Object> paramsMap = new HashMap<>();
+    paramsMap.put("temperature", 0.9);
+    paramsMap.put("maxOutputTokens", 1024);
+    paramsMap.put("topP", 1.0);
+    paramsMap.put("topK", 1);
+    Value parameters = mapToValue(paramsMap);
+
+    // Prompt used in the prediction
+    String instance = "{ \"prompt\": \"Why is the sky blue?\"}";
+    Value.Builder instanceValue = Value.newBuilder();
+    JsonFormat.parser().merge(instance, instanceValue);
+    // Encapsulate the prompt in a correct format for TPUs
+    // Example format: [{'prompt': 'Why is the sky blue?', 'temperature': 0.9}]
+    List<Value> instances = new ArrayList<>();
+    instances.add(instanceValue.build());
+
+    EndpointName endpointName = EndpointName.of(projectId, region, endpointId);
+
+    PredictResponse predictResponse = this.predictionServiceClient
+        .predict(endpointName, instances, parameters);
+    String textResponse = predictResponse.getPredictions(0).getStringValue();
+    System.out.println(textResponse);
+    return textResponse;
+  }
+
+  /**
+   * Converts a map into a protobuf {@code Value} by serializing it to JSON with Gson and parsing
+   * that JSON back with {@code JsonFormat}.
+   *
+   * @param map entries to convert
+   * @return the parsed value
+   * @throws InvalidProtocolBufferException if the generated JSON cannot be merged into the builder
+   */
+  private static Value mapToValue(Map<String, Object> map) throws InvalidProtocolBufferException {
+    Gson gson = new Gson();
+    String json = gson.toJson(map);
+    Value.Builder builder = Value.newBuilder();
+    JsonFormat.parser().merge(json, builder);
+    return builder.build();
+  }
+}
+// [END generativeaionvertexai_gemma2_predict_tpu]
+
diff --git a/aiplatform/src/test/java/aiplatform/Gemma2ParametersTest.java b/aiplatform/src/test/java/aiplatform/Gemma2ParametersTest.java
new file mode 100644
index 00000000000..300eee49c93
--- /dev/null
+++ b/aiplatform/src/test/java/aiplatform/Gemma2ParametersTest.java
@@ -0,0 +1,108 @@
+/*
+ * Copyright 2024 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+import com.google.cloud.aiplatform.v1.EndpointName;
+import com.google.cloud.aiplatform.v1.PredictResponse;
+import com.google.cloud.aiplatform.v1.PredictionServiceClient;
+import com.google.gson.Gson;
+import com.google.protobuf.InvalidProtocolBufferException;
+import com.google.protobuf.Value;
+import com.google.protobuf.util.JsonFormat;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import org.junit.jupiter.api.Test;
+import org.mockito.Mockito;
+import org.mockito.stubbing.Answer;
+
+/**
+ * Checks the request payload that the Gemma2 GPU and TPU samples pass to the prediction client.
+ */
+public class Gemma2ParametersTest {
+
+  static PredictionServiceClient mockGpuPredictionServiceClient;
+  static PredictionServiceClient mockTpuPredictionServiceClient;
+  private static final String INSTANCE_GPU = "{ \"inputs\": \"Why is the sky blue?\"}";
+  private static final String INSTANCE_TPU = "{ \"prompt\": \"Why is the sky blue?\"}";
+  // Placeholder identifiers; the clients are mocks, so no real endpoint is contacted.
+  private static final String PROJECT_ID = "your-project-id";
+  private static final String REGION = "us-central1";
+  private static final String ENDPOINT_ID = "your-endpoint-id";
+
+  /**
+   * Runs both samples against mocked clients and asserts, from inside the stubbed {@code predict}
+   * call, that the instances and parameters they send have the expected content.
+   */
+  @Test
+  public void parametersTest() throws Exception {
+    // Mock GPU and TPU PredictionServiceClient and its response
+    mockGpuPredictionServiceClient = Mockito.mock(PredictionServiceClient.class);
+    mockTpuPredictionServiceClient = Mockito.mock(PredictionServiceClient.class);
+
+    List<Value> expectedGpuInstances = toInstances(INSTANCE_GPU);
+    List<Value> expectedTpuInstances = toInstances(INSTANCE_TPU);
+
+    Map<String, Object> paramsMap = new HashMap<>();
+    paramsMap.put("temperature", 0.9);
+    paramsMap.put("maxOutputTokens", 1024);
+    paramsMap.put("topP", 1.0);
+    paramsMap.put("topK", 1);
+    Value expectedParameters = mapToValue(paramsMap);
+
+    // The samples read getPredictions(0), so the stubbed response must carry one prediction.
+    PredictResponse response =
+        PredictResponse.newBuilder()
+            .addPredictions(Value.newBuilder().setStringValue("mocked prediction").build())
+            .build();
+
+    // Each answer inspects the arguments the sample really passed, so the assertions exercise
+    // the code under test instead of data assembled inside this test.
+    Mockito.when(mockGpuPredictionServiceClient.predict(
+            Mockito.any(EndpointName.class),
+            Mockito.any(List.class),
+            Mockito.any(Value.class)))
+        .thenAnswer(invocation -> {
+          List<Value> instances = invocation.getArgument(1);
+          Value parameters = invocation.getArgument(2);
+          mockGpuResponse(instances, parameters);
+          assertTrue(expectedGpuInstances.equals(instances));
+          assertTrue(expectedParameters.equals(parameters));
+          return response;
+        });
+
+    Mockito.when(mockTpuPredictionServiceClient.predict(
+            Mockito.any(EndpointName.class),
+            Mockito.any(List.class),
+            Mockito.any(Value.class)))
+        .thenAnswer(invocation -> {
+          List<Value> instances = invocation.getArgument(1);
+          Value parameters = invocation.getArgument(2);
+          mockTpuResponse(instances, parameters);
+          assertTrue(expectedTpuInstances.equals(instances));
+          assertTrue(expectedParameters.equals(parameters));
+          return response;
+        });
+
+    new Gemma2PredictGpu(mockGpuPredictionServiceClient)
+        .gemma2PredictGpu(PROJECT_ID, REGION, ENDPOINT_ID);
+    new Gemma2PredictTpu(mockTpuPredictionServiceClient)
+        .gemma2PredictTpu(PROJECT_ID, REGION, ENDPOINT_ID);
+
+    // Fails the test if a stub, and therefore its assertions, was never reached.
+    Mockito.verify(mockGpuPredictionServiceClient).predict(
+        Mockito.any(EndpointName.class), Mockito.any(List.class), Mockito.any(Value.class));
+    Mockito.verify(mockTpuPredictionServiceClient).predict(
+        Mockito.any(EndpointName.class), Mockito.any(List.class), Mockito.any(Value.class));
+  }
+
+  /**
+   * Asserts that the first instance carries an "inputs" field and that the parameters contain
+   * temperature, maxOutputTokens, topP and topK.
+   *
+   * @param instances instances passed to the prediction call
+   * @param parameter parameters passed to the prediction call
+   * @return always {@code null}; the declared return type is kept so existing callers compile
+   */
+  public static Answer<List<Value>> mockGpuResponse(List<Value> instances, Value parameter) {
+
+    assertTrue(instances.get(0).getStructValue().getFieldsMap().containsKey("inputs"));
+    assertTrue(parameter.getStructValue().containsFields("temperature"));
+    assertTrue(parameter.getStructValue().containsFields("maxOutputTokens"));
+    assertTrue(parameter.getStructValue().containsFields("topP"));
+    assertTrue(parameter.getStructValue().containsFields("topK"));
+    return null;
+  }
+
+  /**
+   * Asserts that the first instance carries a "prompt" field and that the parameters contain
+   * temperature, maxOutputTokens, topP and topK.
+   *
+   * @param instances instances passed to the prediction call
+   * @param parameter parameters passed to the prediction call
+   * @return always {@code null}; the declared return type is kept so existing callers compile
+   */
+  public static Answer<List<Value>> mockTpuResponse(List<Value> instances, Value parameter) {
+
+    assertTrue(instances.get(0).getStructValue().getFieldsMap().containsKey("prompt"));
+    assertTrue(parameter.getStructValue().containsFields("temperature"));
+    assertTrue(parameter.getStructValue().containsFields("maxOutputTokens"));
+    assertTrue(parameter.getStructValue().containsFields("topP"));
+    assertTrue(parameter.getStructValue().containsFields("topK"));
+    return null;
+  }
+
+  /** Parses one JSON instance and wraps it in a single-element list. */
+  private static List<Value> toInstances(String json) throws InvalidProtocolBufferException {
+    Value.Builder instanceValue = Value.newBuilder();
+    JsonFormat.parser().merge(json, instanceValue);
+    List<Value> instances = new ArrayList<>();
+    instances.add(instanceValue.build());
+    return instances;
+  }
+
+  /** Converts a map into a protobuf Value by going through its Gson JSON representation. */
+  private static Value mapToValue(Map<String, Object> map) throws InvalidProtocolBufferException {
+    Gson gson = new Gson();
+    String json = gson.toJson(map);
+    Value.Builder builder = Value.newBuilder();
+    JsonFormat.parser().merge(json, builder);
+    return builder.build();
+  }
+}
+
diff --git a/aiplatform/src/test/java/aiplatform/Gemma2PredictTest.java b/aiplatform/src/test/java/aiplatform/Gemma2PredictTest.java
new file mode 100644
index 00000000000..9a78695a2a4
--- /dev/null
+++ b/aiplatform/src/test/java/aiplatform/Gemma2PredictTest.java
@@ -0,0 +1,75 @@
+/*
+ * Copyright 2024 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+import com.google.cloud.aiplatform.v1.EndpointName;
+import com.google.cloud.aiplatform.v1.PredictResponse;
+import com.google.cloud.aiplatform.v1.PredictionServiceClient;
+import com.google.protobuf.Value;
+import java.io.IOException;
+import java.util.List;
+import org.junit.jupiter.api.BeforeAll;
+import org.junit.jupiter.api.Test;
+import org.mockito.Mockito;
+
+/** Runs the Gemma2 GPU and TPU samples against a mocked prediction client. */
+public class Gemma2PredictTest {
+  static PredictionServiceClient mockPredictionServiceClient;
+  static String mockedResponse = "The sky appears blue due to a phenomenon "
+      + "called **Rayleigh scattering**.\n"
+      + "**Here's how it works:**\n"
+      + "* **Sunlight is white:** Sunlight actually contains all the colors of the rainbow.\n"
+      + "* **Scattering:** When sunlight enters the Earth's atmosphere, it collides with tiny gas"
+      + " molecules (mostly nitrogen and oxygen). These collisions cause the light to scatter "
+      + "in different directions.\n"
+      + "* **Blue light scatters most:** Blue light has a shorter wavelength";
+  String projectId = "your-project-id";
+  String region = "us-central1";
+  String endpointId = "your-endpoint-id";
+
+  /** Creates the shared mock client and stubs predict to return the canned response text. */
+  @BeforeAll
+  public static void setUp() {
+    // Canned response holding a single string prediction.
+    Value cannedPrediction = Value.newBuilder().setStringValue(mockedResponse).build();
+    PredictResponse cannedResponse =
+        PredictResponse.newBuilder().addPredictions(cannedPrediction).build();
+
+    // Every predict(endpoint, instances, parameters) call on the mock yields that response.
+    mockPredictionServiceClient = Mockito.mock(PredictionServiceClient.class);
+    Mockito.when(mockPredictionServiceClient.predict(
+            Mockito.any(EndpointName.class),
+            Mockito.any(List.class),
+            Mockito.any(Value.class)))
+        .thenReturn(cannedResponse);
+  }
+
+  /** The TPU sample returns the text of the mocked prediction. */
+  @Test
+  public void testGemma2PredictTpu() throws IOException {
+    String actual =
+        new Gemma2PredictTpu(mockPredictionServiceClient)
+            .gemma2PredictTpu(projectId, region, endpointId);
+
+    assertEquals(mockedResponse, actual);
+  }
+
+  /** The GPU sample returns the text of the mocked prediction. */
+  @Test
+  public void testGemma2PredictGpu() throws IOException {
+    String actual =
+        new Gemma2PredictGpu(mockPredictionServiceClient)
+            .gemma2PredictGpu(projectId, region, endpointId);
+
+    assertEquals(mockedResponse, actual);
+  }
+}