diff --git a/aiplatform/pom.xml b/aiplatform/pom.xml
index 25f989557bf..14e314a1244 100644
--- a/aiplatform/pom.xml
+++ b/aiplatform/pom.xml
@@ -89,6 +89,12 @@
       <version>1.7.1</version>
       <scope>test</scope>
     </dependency>
+    <dependency>
+      <groupId>org.mockito</groupId>
+      <artifactId>mockito-core</artifactId>
+      <version>5.13.0</version>
+      <scope>test</scope>
+    </dependency>
     <dependency>
       <groupId>org.junit.jupiter</groupId>
       <artifactId>junit-jupiter</artifactId>
diff --git a/aiplatform/src/main/java/aiplatform/Gemma2PredictGpu.java b/aiplatform/src/main/java/aiplatform/Gemma2PredictGpu.java
new file mode 100644
index 00000000000..2c3b6c7dace
--- /dev/null
+++ b/aiplatform/src/main/java/aiplatform/Gemma2PredictGpu.java
@@ -0,0 +1,99 @@
+/*
+ * Copyright 2024 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+// [START generativeaionvertexai_gemma2_predict_gpu]
+
+import com.google.cloud.aiplatform.v1.EndpointName;
+import com.google.cloud.aiplatform.v1.PredictResponse;
+import com.google.cloud.aiplatform.v1.PredictionServiceClient;
+import com.google.cloud.aiplatform.v1.PredictionServiceSettings;
+import com.google.gson.Gson;
+import com.google.protobuf.InvalidProtocolBufferException;
+import com.google.protobuf.Value;
+import com.google.protobuf.util.JsonFormat;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+public class Gemma2PredictGpu {
+
+  private final PredictionServiceClient predictionServiceClient;
+
+  // Constructor to inject the PredictionServiceClient
+  public Gemma2PredictGpu(PredictionServiceClient predictionServiceClient) {
+    this.predictionServiceClient = predictionServiceClient;
+  }
+
+  public static void main(String[] args) throws IOException {
+    // TODO(developer): Replace these variables before running the sample.
+    String projectId = "YOUR_PROJECT_ID";
+    String endpointRegion = "us-east4";
+    String endpointId = "YOUR_ENDPOINT_ID";
+
+    PredictionServiceSettings predictionServiceSettings =
+        PredictionServiceSettings.newBuilder()
+            .setEndpoint(String.format("%s-aiplatform.googleapis.com:443", endpointRegion))
+            .build();
+    // try-with-resources closes the client, releasing its background resources.
+    try (PredictionServiceClient predictionServiceClient =
+        PredictionServiceClient.create(predictionServiceSettings)) {
+      Gemma2PredictGpu creator = new Gemma2PredictGpu(predictionServiceClient);
+      creator.gemma2PredictGpu(projectId, endpointRegion, endpointId);
+    }
+  }
+
+  // Runs inference on a Gemma2 model deployed to a Vertex AI endpoint with GPU accelerators
+  // and returns the text of the first prediction, which is also printed to stdout.
+  public String gemma2PredictGpu(String projectId, String region,
+      String endpointId) throws IOException {
+    Map<String, Object> paramsMap = new HashMap<>();
+    paramsMap.put("temperature", 0.9);
+    paramsMap.put("maxOutputTokens", 1024);
+    paramsMap.put("topP", 1.0);
+    paramsMap.put("topK", 1);
+    Value parameters = mapToValue(paramsMap);
+
+    // Prompt used in the prediction
+    String instance = "{ \"inputs\": \"Why is the sky blue?\"}";
+    Value.Builder instanceValue = Value.newBuilder();
+    JsonFormat.parser().merge(instance, instanceValue);
+    // Encapsulate the prompt in a correct format for GPUs
+    // Example format: [{'inputs': 'Why is the sky blue?', 'parameters': {'temperature': 0.9}}]
+    List<Value> instances = new ArrayList<>();
+    instances.add(instanceValue.build());
+
+    EndpointName endpointName = EndpointName.of(projectId, region, endpointId);
+    PredictResponse predictResponse = this.predictionServiceClient
+        .predict(endpointName, instances, parameters);
+    String textResponse = predictResponse.getPredictions(0).getStringValue();
+    System.out.println(textResponse);
+    return textResponse;
+  }
+
+  // Converts a parameter map to a protobuf Value by round-tripping it through JSON.
+  private static Value mapToValue(Map<String, Object> map) throws InvalidProtocolBufferException {
+    Gson gson = new Gson();
+    String json = gson.toJson(map);
+    Value.Builder builder = Value.newBuilder();
+    JsonFormat.parser().merge(json, builder);
+    return builder.build();
+  }
+}
+// [END generativeaionvertexai_gemma2_predict_gpu]
diff --git a/aiplatform/src/main/java/aiplatform/Gemma2PredictTpu.java b/aiplatform/src/main/java/aiplatform/Gemma2PredictTpu.java
new file mode 100644
index 00000000000..de29b1cc111
--- /dev/null
+++ b/aiplatform/src/main/java/aiplatform/Gemma2PredictTpu.java
@@ -0,0 +1,97 @@
+/*
+ * Copyright 2024 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+// [START generativeaionvertexai_gemma2_predict_tpu]
+
+import com.google.cloud.aiplatform.v1.EndpointName;
+import com.google.cloud.aiplatform.v1.PredictResponse;
+import com.google.cloud.aiplatform.v1.PredictionServiceClient;
+import com.google.cloud.aiplatform.v1.PredictionServiceSettings;
+import com.google.gson.Gson;
+import com.google.protobuf.InvalidProtocolBufferException;
+import com.google.protobuf.Value;
+import com.google.protobuf.util.JsonFormat;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+public class Gemma2PredictTpu {
+  private final PredictionServiceClient predictionServiceClient;
+
+  // Constructor to inject the PredictionServiceClient
+  public Gemma2PredictTpu(PredictionServiceClient predictionServiceClient) {
+    this.predictionServiceClient = predictionServiceClient;
+  }
+
+  public static void main(String[] args) throws IOException {
+    // TODO(developer): Replace these variables before running the sample.
+    String projectId = "YOUR_PROJECT_ID";
+    String endpointRegion = "us-west1";
+    String endpointId = "YOUR_ENDPOINT_ID";
+
+    PredictionServiceSettings predictionServiceSettings =
+        PredictionServiceSettings.newBuilder()
+            .setEndpoint(String.format("%s-aiplatform.googleapis.com:443", endpointRegion))
+            .build();
+    // try-with-resources closes the client, releasing its background resources.
+    try (PredictionServiceClient predictionServiceClient =
+        PredictionServiceClient.create(predictionServiceSettings)) {
+      Gemma2PredictTpu creator = new Gemma2PredictTpu(predictionServiceClient);
+      creator.gemma2PredictTpu(projectId, endpointRegion, endpointId);
+    }
+  }
+
+  // Runs inference on a Gemma2 model deployed to a Vertex AI endpoint with TPU accelerators
+  // and returns the text of the first prediction, which is also printed to stdout.
+  public String gemma2PredictTpu(String projectId, String region,
+      String endpointId) throws IOException {
+    Map<String, Object> paramsMap = new HashMap<>();
+    paramsMap.put("temperature", 0.9);
+    paramsMap.put("maxOutputTokens", 1024);
+    paramsMap.put("topP", 1.0);
+    paramsMap.put("topK", 1);
+    Value parameters = mapToValue(paramsMap);
+    // Prompt used in the prediction
+    String instance = "{ \"prompt\": \"Why is the sky blue?\"}";
+    Value.Builder instanceValue = Value.newBuilder();
+    JsonFormat.parser().merge(instance, instanceValue);
+    // Encapsulate the prompt in a correct format for TPUs
+    // Example format: [{'prompt': 'Why is the sky blue?', 'temperature': 0.9}]
+    List<Value> instances = new ArrayList<>();
+    instances.add(instanceValue.build());
+
+    EndpointName endpointName = EndpointName.of(projectId, region, endpointId);
+    PredictResponse predictResponse = this.predictionServiceClient
+        .predict(endpointName, instances, parameters);
+    String textResponse = predictResponse.getPredictions(0).getStringValue();
+    System.out.println(textResponse);
+    return textResponse;
+  }
+
+  // Converts a parameter map to a protobuf Value by round-tripping it through JSON.
+  private static Value mapToValue(Map<String, Object> map) throws InvalidProtocolBufferException {
+    Gson gson = new Gson();
+    String json = gson.toJson(map);
+    Value.Builder builder = Value.newBuilder();
+    JsonFormat.parser().merge(json, builder);
+    return builder.build();
+  }
+}
+// [END generativeaionvertexai_gemma2_predict_tpu]
diff --git a/aiplatform/src/test/java/aiplatform/Gemma2ParametersTest.java b/aiplatform/src/test/java/aiplatform/Gemma2ParametersTest.java
new file mode 100644
index 00000000000..300eee49c93
--- /dev/null
+++ b/aiplatform/src/test/java/aiplatform/Gemma2ParametersTest.java
@@ -0,0 +1,99 @@
+/*
+ * Copyright 2024 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+import com.google.cloud.aiplatform.v1.EndpointName;
+import com.google.cloud.aiplatform.v1.PredictResponse;
+import com.google.cloud.aiplatform.v1.PredictionServiceClient;
+import com.google.protobuf.Value;
+import java.io.IOException;
+import java.util.List;
+import org.junit.jupiter.api.Test;
+import org.mockito.Mockito;
+import org.mockito.stubbing.Answer;
+
+public class Gemma2ParametersTest {
+
+  private static final String PROJECT_ID = "your-project-id";
+  private static final String REGION = "us-central1";
+  private static final String ENDPOINT_ID = "your-endpoint-id";
+  // Canned reply; the samples read the first prediction as a string.
+  private static final PredictResponse RESPONSE =
+      PredictResponse.newBuilder()
+          .addPredictions(Value.newBuilder().setStringValue("mocked response").build())
+          .build();
+
+  static PredictionServiceClient mockGpuPredictionServiceClient;
+  static PredictionServiceClient mockTpuPredictionServiceClient;
+
+  // Runs both samples against mocked clients and checks the request each one sends.
+  @Test
+  public void parametersTest() throws IOException {
+    mockGpuPredictionServiceClient = Mockito.mock(PredictionServiceClient.class);
+    mockTpuPredictionServiceClient = Mockito.mock(PredictionServiceClient.class);
+
+    // The assertions live in the stubbed answers, so they inspect the real arguments
+    // each sample passes to predict(); a failed assertion propagates out of the sample call.
+    Mockito.when(mockGpuPredictionServiceClient.predict(
+            Mockito.any(EndpointName.class), Mockito.anyList(), Mockito.any(Value.class)))
+        .thenAnswer(invocation -> {
+          mockGpuResponse(invocation.getArgument(1), invocation.getArgument(2));
+          return RESPONSE;
+        });
+    Mockito.when(mockTpuPredictionServiceClient.predict(
+            Mockito.any(EndpointName.class), Mockito.anyList(), Mockito.any(Value.class)))
+        .thenAnswer(invocation -> {
+          mockTpuResponse(invocation.getArgument(1), invocation.getArgument(2));
+          return RESPONSE;
+        });
+
+    new Gemma2PredictGpu(mockGpuPredictionServiceClient)
+        .gemma2PredictGpu(PROJECT_ID, REGION, ENDPOINT_ID);
+    new Gemma2PredictTpu(mockTpuPredictionServiceClient)
+        .gemma2PredictTpu(PROJECT_ID, REGION, ENDPOINT_ID);
+
+    // Guard against a vacuous pass: the stubbed answers only run if predict() is called.
+    Mockito.verify(mockGpuPredictionServiceClient).predict(
+        Mockito.any(EndpointName.class), Mockito.anyList(), Mockito.any(Value.class));
+    Mockito.verify(mockTpuPredictionServiceClient).predict(
+        Mockito.any(EndpointName.class), Mockito.anyList(), Mockito.any(Value.class));
+  }
+
+  // Asserts that a GPU request carries the prompt under "inputs" plus all sampling parameters.
+  // Always returns null; the return type is kept for compatibility with existing callers.
+  public static Answer<PredictResponse> mockGpuResponse(List<Value> instances, Value parameter) {
+    assertRequest(instances, parameter, "inputs");
+    return null;
+  }
+
+  // Asserts that a TPU request carries the prompt under "prompt" plus all sampling parameters.
+  // Always returns null; the return type is kept for compatibility with existing callers.
+  public static Answer<PredictResponse> mockTpuResponse(List<Value> instances, Value parameter) {
+    assertRequest(instances, parameter, "prompt");
+    return null;
+  }
+
+  // Shared checks: the first instance has the prompt key and every sampling parameter is set.
+  private static void assertRequest(List<Value> instances, Value parameter, String promptKey) {
+    assertTrue(instances.get(0).getStructValue().containsFields(promptKey));
+    for (String name : new String[] {"temperature", "maxOutputTokens", "topP", "topK"}) {
+      assertTrue(parameter.getStructValue().containsFields(name), "missing parameter " + name);
+    }
+  }
+}
diff --git a/aiplatform/src/test/java/aiplatform/Gemma2PredictTest.java b/aiplatform/src/test/java/aiplatform/Gemma2PredictTest.java
new file mode 100644
index 00000000000..9a78695a2a4
--- /dev/null
+++ b/aiplatform/src/test/java/aiplatform/Gemma2PredictTest.java
@@ -0,0 +1,75 @@
+/*
+ * Copyright 2024 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package aiplatform;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+import com.google.cloud.aiplatform.v1.EndpointName;
+import com.google.cloud.aiplatform.v1.PredictResponse;
+import com.google.cloud.aiplatform.v1.PredictionServiceClient;
+import com.google.protobuf.Value;
+import java.io.IOException;
+import org.junit.jupiter.api.BeforeAll;
+import org.junit.jupiter.api.Test;
+import org.mockito.Mockito;
+
+public class Gemma2PredictTest {
+  static String mockedResponse = "The sky appears blue due to a phenomenon "
+      + "called **Rayleigh scattering**.\n"
+      + "**Here's how it works:**\n"
+      + "* **Sunlight is white:** Sunlight actually contains all the colors of the rainbow.\n"
+      + "* **Scattering:** When sunlight enters the Earth's atmosphere, it collides with tiny gas"
+      + " molecules (mostly nitrogen and oxygen). These collisions cause the light to scatter "
+      + "in different directions.\n"
+      + "* **Blue light scatters most:** Blue light has a shorter wavelength";
+  String projectId = "your-project-id";
+  String region = "us-central1";
+  String endpointId = "your-endpoint-id";
+  static PredictionServiceClient mockPredictionServiceClient;
+
+  // One mock is shared by both tests: every predict() call returns the same canned response.
+  @BeforeAll
+  public static void setUp() {
+    // Mock PredictionServiceClient and its response
+    mockPredictionServiceClient = Mockito.mock(PredictionServiceClient.class);
+    PredictResponse predictResponse =
+        PredictResponse.newBuilder()
+            .addPredictions(Value.newBuilder().setStringValue(mockedResponse).build())
+            .build();
+    Mockito.when(mockPredictionServiceClient.predict(
+            Mockito.any(EndpointName.class),
+            Mockito.anyList(),
+            Mockito.any(Value.class)))
+        .thenReturn(predictResponse);
+  }
+
+  @Test
+  public void testGemma2PredictTpu() throws IOException {
+    Gemma2PredictTpu creator = new Gemma2PredictTpu(mockPredictionServiceClient);
+    String response = creator.gemma2PredictTpu(projectId, region, endpointId);
+
+    assertEquals(mockedResponse, response);
+  }
+
+  @Test
+  public void testGemma2PredictGpu() throws IOException {
+    Gemma2PredictGpu creator = new Gemma2PredictGpu(mockPredictionServiceClient);
+    String response = creator.gemma2PredictGpu(projectId, region, endpointId);
+
+    assertEquals(mockedResponse, response);
+  }
+}