From b0e22b0734113fa477b15995d4235446ebaba45e Mon Sep 17 00:00:00 2001 From: Tetiana Yahodska Date: Thu, 19 Sep 2024 10:52:26 +0200 Subject: [PATCH 01/10] Added generativeaionvertexai_gemma2_predict_tpu and generativeaionvertexai_gemma2_predict_gpu sample, created test --- aiplatform/pom.xml | 6 + .../java/aiplatform/Gemma2PredictGpu.java | 87 +++++++++ .../java/aiplatform/Gemma2PredictTpu.java | 88 +++++++++ .../java/aiplatform/Gemma2PredictTest.java | 180 ++++++++++++++++++ .../vertexai/gemini/Gemma2PredictGpu.java | 82 ++++++++ 5 files changed, 443 insertions(+) create mode 100644 aiplatform/src/main/java/aiplatform/Gemma2PredictGpu.java create mode 100644 aiplatform/src/main/java/aiplatform/Gemma2PredictTpu.java create mode 100644 aiplatform/src/test/java/aiplatform/Gemma2PredictTest.java create mode 100644 vertexai/snippets/src/main/java/vertexai/gemini/Gemma2PredictGpu.java diff --git a/aiplatform/pom.xml b/aiplatform/pom.xml index b53cee5ef2f..2b098e3a364 100644 --- a/aiplatform/pom.xml +++ b/aiplatform/pom.xml @@ -89,5 +89,11 @@ 1.7.1 test + + org.mockito + mockito-core + 5.13.0 + test + diff --git a/aiplatform/src/main/java/aiplatform/Gemma2PredictGpu.java b/aiplatform/src/main/java/aiplatform/Gemma2PredictGpu.java new file mode 100644 index 00000000000..d3bba2bd448 --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/Gemma2PredictGpu.java @@ -0,0 +1,87 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package aiplatform; + +// [START generativeaionvertexai_gemma2_predict_gpu] + +import com.google.cloud.aiplatform.v1.EndpointName; +import com.google.cloud.aiplatform.v1.PredictRequest; +import com.google.cloud.aiplatform.v1.PredictResponse; +import com.google.cloud.aiplatform.v1.PredictionServiceClient; +import com.google.cloud.aiplatform.v1.PredictionServiceSettings; +import com.google.protobuf.Value; +import com.google.protobuf.util.JsonFormat; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +public class Gemma2PredictGpu { + + public static void main(String[] args) throws IOException { + // TODO(developer): Update & uncomment line below + // String projectId = "your-project-id"; + String projectId = "rsamborski-ai-hypercomputer"; + String region = "us-east4"; + String endpointId = "323876543124209664"; + String parameters = + "{\n" + + " \"temperature\": 0.3,\n" + + " \"maxDecodeSteps\": 200,\n" + + " \"topP\": 0.8,\n" + + " \"topK\": 40\n" + + "}"; + + gemma2PredictGpu(projectId, region, endpointId, parameters); + } + + // Sample to run inference on a Gemma2 model deployed to a Vertex AI endpoint with GPU accelerators. 
+ public static String gemma2PredictGpu(String projectId, String region, String endpointId, String parameters) + throws IOException { + PredictionServiceSettings predictionServiceSettings = + PredictionServiceSettings.newBuilder() + .setEndpoint(String.format("%s-aiplatform.googleapis.com:443", region)) + .build(); + // Prompt used in the prediction + String prompt = "Why is the sky blue?"; + + Value.Builder parameterValueBuilder = Value.newBuilder(); + JsonFormat.parser().merge(parameters, parameterValueBuilder); + Value parameterValue = parameterValueBuilder.build(); + + Value promptValue = Value.newBuilder().setStringValue(prompt).build(); + + List instances = new ArrayList<>(); + instances.add(promptValue); + instances.add(parameterValue); + + try (PredictionServiceClient predictionServiceClient = + PredictionServiceClient.create(predictionServiceSettings)) { + // Call the Gemma2 endpoint + EndpointName endpointName = EndpointName.of(projectId, region, endpointId); + PredictRequest predictRequest = + PredictRequest.newBuilder() + .setEndpoint(endpointName.toString()) + .addAllInstances(instances) + .build(); + PredictResponse predictResponse = predictionServiceClient.predict(predictRequest); + String textResponse = predictResponse.getPredictions(0).getStringValue(); + System.out.println(textResponse); + return textResponse; + } + } +} +// [END generativeaionvertexai_gemma2_predict_gpu] \ No newline at end of file diff --git a/aiplatform/src/main/java/aiplatform/Gemma2PredictTpu.java b/aiplatform/src/main/java/aiplatform/Gemma2PredictTpu.java new file mode 100644 index 00000000000..ce643e754bf --- /dev/null +++ b/aiplatform/src/main/java/aiplatform/Gemma2PredictTpu.java @@ -0,0 +1,88 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START generativeaionvertexai_gemma2_predict_tpu] + +import com.google.cloud.aiplatform.v1.EndpointName; +import com.google.cloud.aiplatform.v1.PredictRequest; +import com.google.cloud.aiplatform.v1.PredictResponse; +import com.google.cloud.aiplatform.v1.PredictionServiceClient; +import com.google.cloud.aiplatform.v1.PredictionServiceSettings; +import com.google.protobuf.Value; +import com.google.protobuf.util.JsonFormat; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +public class Gemma2PredictTpu { + + public static void main(String[] args) throws IOException { + // TODO(developer): Update & uncomment line below + // String projectId = "your-project-id"; + String projectId = "rsamborski-ai-hypercomputer"; + String region = "us-west1"; + String endpointId = "9194824316951199744"; + String parameters = + "{\n" + + " \"temperature\": 0.3,\n" + + " \"maxDecodeSteps\": 200,\n" + + " \"topP\": 0.8,\n" + + " \"topK\": 40\n" + + "}"; + + gemma2PredictTpu(projectId, region, endpointId, parameters); + } + + // Sample to run inference on a Gemma2 model deployed to a Vertex AI endpoint with TPU accelerators. 
+ public static String gemma2PredictTpu(String projectId, String region, String endpointId, String parameters) + throws IOException { + PredictionServiceSettings predictionServiceSettings = + PredictionServiceSettings.newBuilder() + .setEndpoint(String.format("%s-aiplatform.googleapis.com:443", region)) + .build(); + // Prompt used in the prediction + String prompt = "Why is the sky blue?"; + + Value.Builder parameterValueBuilder = Value.newBuilder(); + JsonFormat.parser().merge(parameters, parameterValueBuilder); + Value parameterValue = parameterValueBuilder.build(); + + Value promptValue = Value.newBuilder().setStringValue(prompt).build(); + + List instances = new ArrayList<>(); + instances.add(promptValue); + instances.add(parameterValue); + + try (PredictionServiceClient predictionServiceClient = + PredictionServiceClient.create(predictionServiceSettings)) { + // Call the Gemma2 endpoint + EndpointName endpointName = EndpointName.of(projectId, region, endpointId); + PredictRequest predictRequest = + PredictRequest.newBuilder() + .setEndpoint(endpointName.toString()) + .addAllInstances(instances) + .build(); + PredictResponse predictResponse = predictionServiceClient.predict(predictRequest); + String textResponse = predictResponse.getPredictions(0).getStringValue(); + System.out.println(textResponse); + return textResponse; + } + } +} +// [END generativeaionvertexai_gemma2_predict_tpu] + diff --git a/aiplatform/src/test/java/aiplatform/Gemma2PredictTest.java b/aiplatform/src/test/java/aiplatform/Gemma2PredictTest.java new file mode 100644 index 00000000000..289ec1817db --- /dev/null +++ b/aiplatform/src/test/java/aiplatform/Gemma2PredictTest.java @@ -0,0 +1,180 @@ +package aiplatform; + +import com.google.cloud.aiplatform.v1.EndpointName; +import com.google.cloud.aiplatform.v1.PredictRequest; +import com.google.cloud.aiplatform.v1.PredictResponse; +import com.google.cloud.aiplatform.v1.PredictionServiceClient; +import 
com.google.cloud.aiplatform.v1.PredictionServiceSettings; +import com.google.protobuf.Value; +import com.google.protobuf.util.JsonFormat; +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mockito; +import org.mockito.junit.MockitoJUnitRunner; + +@RunWith(MockitoJUnitRunner.class) +public class Gemma2PredictTest { + + // Global variables + private static final String PROJECT_ID = "your-project-id"; + private static final String GPU_ENDPOINT_REGION = "us-east1"; + private static final String GPU_ENDPOINT_ID = "123456789"; // Mock ID used to check if GPU was called + private static final String TPU_ENDPOINT_REGION = "us-west1"; + private static final String TPU_ENDPOINT_ID = "987654321"; // Mock ID used to check if TPU was called + + // MOCKED RESPONSE + private static final String MODEL_RESPONSES = + "The sky appears blue due to a phenomenon called **Rayleigh scattering**.\n" + + "**Here's how it works:**\n" + + "1. **Sunlight:** Sunlight is composed of all the colors of the rainbow.\n" + + "2. **Earth's Atmosphere:** When sunlight enters the Earth's atmosphere, it collides with tiny particles like nitrogen and oxygen molecules.\n" + + "3. **Scattering:** These particles scatter the sunlight in all directions. However, blue light (which has a shorter wavelength) is scattered more effectively than other colors.\n" + + "4. **Our Perception:** As a result, we see a blue sky because the scattered blue light reaches our eyes from all directions.\n" + + "**Why not other colors?**\n" + + "* **Violet light** has an even shorter wavelength than blue and is scattered even more. However, our eyes are less sensitive to violet light, so we perceive the sky as blue.\n" + + "* **Longer wavelengths** like red, orange, and yellow are scattered less and travel more directly through the atmosphere. 
This is why we see these colors during sunrise and sunset, when sunlight has to travel through more of the atmosphere.\n"; + private PredictResponse mockPredict(String endpoint, List instances) + throws IOException { + String gpuEndpoint = + String.format( + "projects/%s/locations/%s/endpoints/%s", + PROJECT_ID, GPU_ENDPOINT_REGION, GPU_ENDPOINT_ID); + String tpuEndpoint = + String.format( + "projects/%s/locations/%s/endpoints/%s", + PROJECT_ID, TPU_ENDPOINT_REGION, TPU_ENDPOINT_ID); + + Map instanceFields = + instances.get(0).getStructValue().getFieldsMap(); + + if (endpoint.equals(gpuEndpoint)) { + Assert.assertTrue(instanceFields.containsKey("inputs") && instanceFields.get("inputs").hasStringValue()); + } else if (endpoint.equals(tpuEndpoint)) { + // Assertions for TPU format + } else { + Assert.fail("Unexpected endpoint: " + endpoint); + } + + PredictResponse response = + PredictResponse.newBuilder() + .addPredictions(Value.newBuilder().setStringValue(MODEL_RESPONSES).build()) + .build(); + return response; + } + } + + @Test + public void testGemma2PredictGpu() throws IOException { + PredictionServiceClient mockClient = Mockito.mock(PredictionServiceClient.class); + Mockito.when(mockClient.predict(Mockito.any(PredictRequest.class))) + .thenAnswer( + invocation -> { + PredictRequest request = invocation.getArgument(0); + return mockPredict(request.getEndpoint()); + }); + + String response = + gemma2PredictGpu( + mockClient, PROJECT_ID, GPU_ENDPOINT_REGION, GPU_ENDPOINT_ID, "Why is the sky blue?"); + Assert.assertTrue(response.contains("Rayleigh scattering")); + } + + @Test + public void testGemma2PredictTpu() throws IOException { + PredictionServiceClient mockClient = Mockito.mock(PredictionServiceClient.class); + Mockito.when(mockClient.predict(Mockito.any(PredictRequest.class))) + .thenAnswer( + invocation -> { + PredictRequest request = invocation.getArgument(0); + return mockPredict(request.getEndpoint()); + }); + + String response = + gemma2PredictTpu( + 
mockClient, PROJECT_ID, TPU_ENDPOINT_REGION, TPU_ENDPOINT_ID, "Why is the sky blue?"); + Assert.assertTrue(response.contains("Rayleigh scattering")); + } + + // Implement actual logic for gemma2PredictGpu and gemma2PredictTpu + public static String gemma2PredictGpu( + PredictionServiceClient predictionServiceClient, + String projectId, + String endpointRegion, + String endpointId, + String prompt) + throws IOException { + PredictionServiceSettings predictionServiceSettings = + PredictionServiceSettings.newBuilder() + .setEndpoint(String.format("%s-aiplatform.googleapis.com:443", endpointRegion)) + .build(); + + // Default configuration + Map config = new HashMap<>(); + config.put("max_tokens", 1024); + config.put("temperature", 0.9); + config.put("top_p", 1.0); + config.put("top_k", 1); + + // Encapsulate the prompt in a correct format for GPUs + Map input = new HashMap<>(); + input.put("inputs", prompt); + input.put("parameters", config); + + Value.Builder instanceValue = Value.newBuilder(); + JsonFormat.parser().merge(JsonFormat.printer().print(Value.of(input)), instanceValue); + List instances = new ArrayList<>(); + instances.add(instanceValue.build()); + + // Call the Gemma2 endpoint + EndpointName endpointName = EndpointName.of(projectId, endpointRegion, endpointId); + PredictRequest predictRequest = + PredictRequest.newBuilder() + .setEndpoint(endpointName.toString()) + .addAllInstances(instances) + .build(); + PredictResponse predictResponse = predictionServiceClient.predict(predictRequest); + return predictResponse.getPredictions(0).getStringValue(); + } + + public static String gemma2PredictTpu( + PredictionServiceClient predictionServiceClient, + String projectId, + String endpointRegion, + String endpointId, + String prompt) + throws IOException { + PredictionServiceSettings predictionServiceSettings = + PredictionServiceSettings.newBuilder() + .setEndpoint(String.format("%s-aiplatform.googleapis.com:443", endpointRegion)) + .build(); + // Prompt used in 
the prediction + + Map input = new HashMap<>(); + input.put("prompt", prompt); + input.put("max_tokens", 1024); + input.put("temperature", 0.9); + input.put("top_p", 1.0); + input.put("top_k", 1); + + Value.Builder instanceValue = Value.newBuilder(); + JsonFormat.parser().merge(JsonFormat.printer().print(Value.of(input)), instanceValue); + List instances = new ArrayList<>(); + instances.add(instanceValue.build()); + + // Call the Gemma2 endpoint + EndpointName endpointName = EndpointName.of(projectId, endpointRegion, endpointId); + PredictRequest predictRequest = + PredictRequest.newBuilder() + .setEndpoint(endpointName.toString()) + .addAllInstances(instances) + .build(); + PredictResponse predictResponse = predictionServiceClient.predict(predictRequest); + return predictResponse.getPredictions(0).getStringValue(); + } +} \ No newline at end of file diff --git a/vertexai/snippets/src/main/java/vertexai/gemini/Gemma2PredictGpu.java b/vertexai/snippets/src/main/java/vertexai/gemini/Gemma2PredictGpu.java new file mode 100644 index 00000000000..eceb490811c --- /dev/null +++ b/vertexai/snippets/src/main/java/vertexai/gemini/Gemma2PredictGpu.java @@ -0,0 +1,82 @@ +package vertexai.gemini; + +import com.google.api.gax.core.FixedCredentialsProvider; +import com.google.auth.oauth2.GoogleCredentials; +import com.google.cloud.aiplatform.v1.EndpointName; +import com.google.cloud.aiplatform.v1.PredictRequest; +import com.google.cloud.aiplatform.v1.PredictResponse; +import com.google.cloud.aiplatform.v1.PredictionServiceClient; +import com.google.cloud.aiplatform.v1.PredictionServiceSettings; +import com.google.protobuf.Value; +import com.google.protobuf.util.JsonFormat; +import java.io.IOException; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +public class Gemma2PredictGpu { + + public static void main(String[] args) throws IOException { + if (args.length != 2) { + System.out.println( + "Usage: java Gemma2PredictGpu "); + System.exit(1); + 
} + + String endpointRegion = args[0]; + String endpointId = args[1]; + gemma2PredictGpu(endpointRegion, endpointId); + } + + public static String gemma2PredictGpu(String endpointRegion, String endpointId) + throws IOException { + // TODO(developer): Update & uncomment line below + // String projectId = "your-project-id"; + String projectId = System.getenv("GOOGLE_CLOUD_PROJECT"); + + // Default configuration + Map config = new HashMap<>(); + config.put("max_tokens", 1024); + config.put("temperature", 0.9); + config.put("top_p", 1.0); + config.put("top_k", 1); + + // Prompt used in the prediction + String prompt = "Why is the sky blue?"; + + // Encapsulate the prompt in a correct format for GPUs + // Example format: [{'inputs': 'Why is the sky blue?', 'parameters': {'temperature': 0.9}}] + Map input = new HashMap<>(); + input.put("inputs", prompt); + input.put("parameters", config); + + // Convert input message to a list of GAPIC instances for model input + Value.Builder valueBuilder = Value.newBuilder(); + JsonFormat.parser().merge(JsonFormat.printer().print(input), valueBuilder); + Value instance = valueBuilder.build(); + + // Create a client + GoogleCredentials credentials = GoogleCredentials.getApplicationDefault(); + PredictionServiceSettings predictionServiceSettings = + PredictionServiceSettings.newBuilder() + .setCredentialsProvider(FixedCredentialsProvider.create(credentials)) + .setEndpoint(String.format("%s-aiplatform.googleapis.com:443", endpointRegion)) + .build(); + + try (PredictionServiceClient predictionServiceClient = + PredictionServiceClient.create(predictionServiceSettings)) { + // Call the Gemma2 endpoint + String endpointName = + EndpointName.of(projectId, endpointRegion, endpointId).toString(); + PredictRequest predictRequest = + PredictRequest.newBuilder() + .setEndpoint(endpointName) + .addAllInstances(Collections.singletonList(instance)) + .build(); + PredictResponse predictResponse = predictionServiceClient.predict(predictRequest); + 
String textResponse = predictResponse.getPredictions(0).getStringValue(); + System.out.println(textResponse); + return textResponse; + } + } +} \ No newline at end of file From 10b1bf39991646bb99a08c1b7ec1de4343faa425 Mon Sep 17 00:00:00 2001 From: Tetiana Yahodska Date: Thu, 19 Sep 2024 15:08:31 +0200 Subject: [PATCH 02/10] Fixed instance format --- aiplatform/src/main/java/aiplatform/Gemma2PredictGpu.java | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/aiplatform/src/main/java/aiplatform/Gemma2PredictGpu.java b/aiplatform/src/main/java/aiplatform/Gemma2PredictGpu.java index d3bba2bd448..0ce1c2c916a 100644 --- a/aiplatform/src/main/java/aiplatform/Gemma2PredictGpu.java +++ b/aiplatform/src/main/java/aiplatform/Gemma2PredictGpu.java @@ -56,16 +56,17 @@ public static String gemma2PredictGpu(String projectId, String region, String en .setEndpoint(String.format("%s-aiplatform.googleapis.com:443", region)) .build(); // Prompt used in the prediction - String prompt = "Why is the sky blue?"; + String instance = "{ \"content\": \"Why is the sky blue?\"}"; Value.Builder parameterValueBuilder = Value.newBuilder(); JsonFormat.parser().merge(parameters, parameterValueBuilder); Value parameterValue = parameterValueBuilder.build(); - Value promptValue = Value.newBuilder().setStringValue(prompt).build(); + Value.Builder instanceValue = Value.newBuilder(); + JsonFormat.parser().merge(instance, instanceValue); List instances = new ArrayList<>(); - instances.add(promptValue); + instances.add(instanceValue.build()); instances.add(parameterValue); try (PredictionServiceClient predictionServiceClient = From 3ac56f9378fd1028ee8b14adccd41f6f97fbd1ee Mon Sep 17 00:00:00 2001 From: Tetiana Yahodska Date: Thu, 19 Sep 2024 16:50:35 +0200 Subject: [PATCH 03/10] Fixed test and instance format for Gemma2PredictTpu --- aiplatform/pom.xml | 6 + .../java/aiplatform/Gemma2PredictGpu.java | 12 +- .../java/aiplatform/Gemma2PredictTpu.java | 17 +-- 
.../java/aiplatform/Gemma2PredictTest.java | 136 +++++------------- 4 files changed, 49 insertions(+), 122 deletions(-) diff --git a/aiplatform/pom.xml b/aiplatform/pom.xml index 2b098e3a364..14e314a1244 100644 --- a/aiplatform/pom.xml +++ b/aiplatform/pom.xml @@ -95,5 +95,11 @@ 5.13.0 test + + org.junit.jupiter + junit-jupiter + RELEASE + test + diff --git a/aiplatform/src/main/java/aiplatform/Gemma2PredictGpu.java b/aiplatform/src/main/java/aiplatform/Gemma2PredictGpu.java index 0ce1c2c916a..e3b2d3b2942 100644 --- a/aiplatform/src/main/java/aiplatform/Gemma2PredictGpu.java +++ b/aiplatform/src/main/java/aiplatform/Gemma2PredictGpu.java @@ -56,7 +56,7 @@ public static String gemma2PredictGpu(String projectId, String region, String en .setEndpoint(String.format("%s-aiplatform.googleapis.com:443", region)) .build(); // Prompt used in the prediction - String instance = "{ \"content\": \"Why is the sky blue?\"}"; + String instance = "{ \"inputs\": \"Why is the sky blue?\"}"; Value.Builder parameterValueBuilder = Value.newBuilder(); JsonFormat.parser().merge(parameters, parameterValueBuilder); @@ -67,18 +67,14 @@ public static String gemma2PredictGpu(String projectId, String region, String en List instances = new ArrayList<>(); instances.add(instanceValue.build()); - instances.add(parameterValue); try (PredictionServiceClient predictionServiceClient = PredictionServiceClient.create(predictionServiceSettings)) { // Call the Gemma2 endpoint EndpointName endpointName = EndpointName.of(projectId, region, endpointId); - PredictRequest predictRequest = - PredictRequest.newBuilder() - .setEndpoint(endpointName.toString()) - .addAllInstances(instances) - .build(); - PredictResponse predictResponse = predictionServiceClient.predict(predictRequest); + + PredictResponse predictResponse = predictionServiceClient + .predict(endpointName, instances, parameterValue); String textResponse = predictResponse.getPredictions(0).getStringValue(); System.out.println(textResponse); return 
textResponse; diff --git a/aiplatform/src/main/java/aiplatform/Gemma2PredictTpu.java b/aiplatform/src/main/java/aiplatform/Gemma2PredictTpu.java index ce643e754bf..0ffd2d3f4c0 100644 --- a/aiplatform/src/main/java/aiplatform/Gemma2PredictTpu.java +++ b/aiplatform/src/main/java/aiplatform/Gemma2PredictTpu.java @@ -56,28 +56,25 @@ public static String gemma2PredictTpu(String projectId, String region, String en .setEndpoint(String.format("%s-aiplatform.googleapis.com:443", region)) .build(); // Prompt used in the prediction - String prompt = "Why is the sky blue?"; + String instance = "{ \"prompt\": \"Why is the sky blue?\"}"; Value.Builder parameterValueBuilder = Value.newBuilder(); JsonFormat.parser().merge(parameters, parameterValueBuilder); Value parameterValue = parameterValueBuilder.build(); - Value promptValue = Value.newBuilder().setStringValue(prompt).build(); + Value.Builder instanceValue = Value.newBuilder(); + JsonFormat.parser().merge(instance, instanceValue); List instances = new ArrayList<>(); - instances.add(promptValue); - instances.add(parameterValue); + instances.add(instanceValue.build()); try (PredictionServiceClient predictionServiceClient = PredictionServiceClient.create(predictionServiceSettings)) { // Call the Gemma2 endpoint EndpointName endpointName = EndpointName.of(projectId, region, endpointId); - PredictRequest predictRequest = - PredictRequest.newBuilder() - .setEndpoint(endpointName.toString()) - .addAllInstances(instances) - .build(); - PredictResponse predictResponse = predictionServiceClient.predict(predictRequest); + + PredictResponse predictResponse = predictionServiceClient + .predict(endpointName, instances, parameterValue); String textResponse = predictResponse.getPredictions(0).getStringValue(); System.out.println(textResponse); return textResponse; diff --git a/aiplatform/src/test/java/aiplatform/Gemma2PredictTest.java b/aiplatform/src/test/java/aiplatform/Gemma2PredictTest.java index 289ec1817db..7bc35ba71af 100644 --- 
a/aiplatform/src/test/java/aiplatform/Gemma2PredictTest.java +++ b/aiplatform/src/test/java/aiplatform/Gemma2PredictTest.java @@ -1,16 +1,11 @@ package aiplatform; -import com.google.cloud.aiplatform.v1.EndpointName; import com.google.cloud.aiplatform.v1.PredictRequest; import com.google.cloud.aiplatform.v1.PredictResponse; import com.google.cloud.aiplatform.v1.PredictionServiceClient; -import com.google.cloud.aiplatform.v1.PredictionServiceSettings; import com.google.protobuf.Value; -import com.google.protobuf.util.JsonFormat; + import java.io.IOException; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; import java.util.Map; import org.junit.Assert; import org.junit.Test; @@ -18,15 +13,25 @@ import org.mockito.Mockito; import org.mockito.junit.MockitoJUnitRunner; +import static aiplatform.Gemma2PredictGpu.gemma2PredictGpu; +import static aiplatform.Gemma2PredictTpu.gemma2PredictTpu; + @RunWith(MockitoJUnitRunner.class) public class Gemma2PredictTest { // Global variables - private static final String PROJECT_ID = "your-project-id"; + private static final String PROJECT_ID = "rsamborski-ai-hypercomputer"; private static final String GPU_ENDPOINT_REGION = "us-east1"; private static final String GPU_ENDPOINT_ID = "123456789"; // Mock ID used to check if GPU was called private static final String TPU_ENDPOINT_REGION = "us-west1"; private static final String TPU_ENDPOINT_ID = "987654321"; // Mock ID used to check if TPU was called + private static final String PARAMETERS = + "{\n" + + " \"temperature\": 0.3,\n" + + " \"maxDecodeSteps\": 200,\n" + + " \"topP\": 0.8,\n" + + " \"topK\": 40\n" + + "}"; // MOCKED RESPONSE private static final String MODEL_RESPONSES = @@ -39,8 +44,8 @@ public class Gemma2PredictTest { + "**Why not other colors?**\n" + "* **Violet light** has an even shorter wavelength than blue and is scattered even more. 
However, our eyes are less sensitive to violet light, so we perceive the sky as blue.\n" + "* **Longer wavelengths** like red, orange, and yellow are scattered less and travel more directly through the atmosphere. This is why we see these colors during sunrise and sunset, when sunlight has to travel through more of the atmosphere.\n"; - private PredictResponse mockPredict(String endpoint, List instances) - throws IOException { + + private PredictResponse mockPredict(String endpoint, Value instance) { String gpuEndpoint = String.format( "projects/%s/locations/%s/endpoints/%s", @@ -51,22 +56,21 @@ private PredictResponse mockPredict(String endpoint, List instances) PROJECT_ID, TPU_ENDPOINT_REGION, TPU_ENDPOINT_ID); Map instanceFields = - instances.get(0).getStructValue().getFieldsMap(); - - if (endpoint.equals(gpuEndpoint)) { - Assert.assertTrue(instanceFields.containsKey("inputs") && instanceFields.get("inputs").hasStringValue()); - } else if (endpoint.equals(tpuEndpoint)) { - // Assertions for TPU format - } else { - Assert.fail("Unexpected endpoint: " + endpoint); - } - - PredictResponse response = - PredictResponse.newBuilder() - .addPredictions(Value.newBuilder().setStringValue(MODEL_RESPONSES).build()) - .build(); - return response; + instance.getStructValue().getFieldsMap(); + + if (endpoint.equals(gpuEndpoint)) { + Assert.assertTrue(instanceFields.containsKey("inputs") && instanceFields.get("inputs").hasStringValue()); + } else if (endpoint.equals(tpuEndpoint)) { + Assert.assertTrue(instanceFields.containsKey("prompt") && instanceFields.get("prompt").hasStringValue()); + } else { + Assert.fail("Unexpected endpoint: " + endpoint); } + + PredictResponse response = + PredictResponse.newBuilder() + .addPredictions(Value.newBuilder().setStringValue(MODEL_RESPONSES).build()) + .build(); + return response; } @Test @@ -76,12 +80,12 @@ public void testGemma2PredictGpu() throws IOException { .thenAnswer( invocation -> { PredictRequest request = 
invocation.getArgument(0); - return mockPredict(request.getEndpoint()); + return mockPredict(request.getEndpoint(), request.getInstances(0)); }); String response = gemma2PredictGpu( - mockClient, PROJECT_ID, GPU_ENDPOINT_REGION, GPU_ENDPOINT_ID, "Why is the sky blue?"); + PROJECT_ID, GPU_ENDPOINT_REGION, GPU_ENDPOINT_ID, PARAMETERS); Assert.assertTrue(response.contains("Rayleigh scattering")); } @@ -92,89 +96,13 @@ public void testGemma2PredictTpu() throws IOException { .thenAnswer( invocation -> { PredictRequest request = invocation.getArgument(0); - return mockPredict(request.getEndpoint()); + return mockPredict(request.getEndpoint(), request.getInstances(0)); }); String response = gemma2PredictTpu( - mockClient, PROJECT_ID, TPU_ENDPOINT_REGION, TPU_ENDPOINT_ID, "Why is the sky blue?"); + PROJECT_ID, TPU_ENDPOINT_REGION, TPU_ENDPOINT_ID, PARAMETERS); Assert.assertTrue(response.contains("Rayleigh scattering")); } - // Implement actual logic for gemma2PredictGpu and gemma2PredictTpu - public static String gemma2PredictGpu( - PredictionServiceClient predictionServiceClient, - String projectId, - String endpointRegion, - String endpointId, - String prompt) - throws IOException { - PredictionServiceSettings predictionServiceSettings = - PredictionServiceSettings.newBuilder() - .setEndpoint(String.format("%s-aiplatform.googleapis.com:443", endpointRegion)) - .build(); - - // Default configuration - Map config = new HashMap<>(); - config.put("max_tokens", 1024); - config.put("temperature", 0.9); - config.put("top_p", 1.0); - config.put("top_k", 1); - - // Encapsulate the prompt in a correct format for GPUs - Map input = new HashMap<>(); - input.put("inputs", prompt); - input.put("parameters", config); - - Value.Builder instanceValue = Value.newBuilder(); - JsonFormat.parser().merge(JsonFormat.printer().print(Value.of(input)), instanceValue); - List instances = new ArrayList<>(); - instances.add(instanceValue.build()); - - // Call the Gemma2 endpoint - EndpointName 
endpointName = EndpointName.of(projectId, endpointRegion, endpointId); - PredictRequest predictRequest = - PredictRequest.newBuilder() - .setEndpoint(endpointName.toString()) - .addAllInstances(instances) - .build(); - PredictResponse predictResponse = predictionServiceClient.predict(predictRequest); - return predictResponse.getPredictions(0).getStringValue(); - } - - public static String gemma2PredictTpu( - PredictionServiceClient predictionServiceClient, - String projectId, - String endpointRegion, - String endpointId, - String prompt) - throws IOException { - PredictionServiceSettings predictionServiceSettings = - PredictionServiceSettings.newBuilder() - .setEndpoint(String.format("%s-aiplatform.googleapis.com:443", endpointRegion)) - .build(); - // Prompt used in the prediction - - Map input = new HashMap<>(); - input.put("prompt", prompt); - input.put("max_tokens", 1024); - input.put("temperature", 0.9); - input.put("top_p", 1.0); - input.put("top_k", 1); - - Value.Builder instanceValue = Value.newBuilder(); - JsonFormat.parser().merge(JsonFormat.printer().print(Value.of(input)), instanceValue); - List instances = new ArrayList<>(); - instances.add(instanceValue.build()); - - // Call the Gemma2 endpoint - EndpointName endpointName = EndpointName.of(projectId, endpointRegion, endpointId); - PredictRequest predictRequest = - PredictRequest.newBuilder() - .setEndpoint(endpointName.toString()) - .addAllInstances(instances) - .build(); - PredictResponse predictResponse = predictionServiceClient.predict(predictRequest); - return predictResponse.getPredictions(0).getStringValue(); - } } \ No newline at end of file From 8632dbeb7ebada737555f8ecca905a250f347983 Mon Sep 17 00:00:00 2001 From: Tetiana Yahodska Date: Fri, 20 Sep 2024 09:44:47 +0200 Subject: [PATCH 04/10] Deleted class for vertexai package --- .../vertexai/gemini/Gemma2PredictGpu.java | 82 ------------------- 1 file changed, 82 deletions(-) delete mode 100644 
vertexai/snippets/src/main/java/vertexai/gemini/Gemma2PredictGpu.java diff --git a/vertexai/snippets/src/main/java/vertexai/gemini/Gemma2PredictGpu.java b/vertexai/snippets/src/main/java/vertexai/gemini/Gemma2PredictGpu.java deleted file mode 100644 index eceb490811c..00000000000 --- a/vertexai/snippets/src/main/java/vertexai/gemini/Gemma2PredictGpu.java +++ /dev/null @@ -1,82 +0,0 @@ -package vertexai.gemini; - -import com.google.api.gax.core.FixedCredentialsProvider; -import com.google.auth.oauth2.GoogleCredentials; -import com.google.cloud.aiplatform.v1.EndpointName; -import com.google.cloud.aiplatform.v1.PredictRequest; -import com.google.cloud.aiplatform.v1.PredictResponse; -import com.google.cloud.aiplatform.v1.PredictionServiceClient; -import com.google.cloud.aiplatform.v1.PredictionServiceSettings; -import com.google.protobuf.Value; -import com.google.protobuf.util.JsonFormat; -import java.io.IOException; -import java.util.Collections; -import java.util.HashMap; -import java.util.Map; - -public class Gemma2PredictGpu { - - public static void main(String[] args) throws IOException { - if (args.length != 2) { - System.out.println( - "Usage: java Gemma2PredictGpu "); - System.exit(1); - } - - String endpointRegion = args[0]; - String endpointId = args[1]; - gemma2PredictGpu(endpointRegion, endpointId); - } - - public static String gemma2PredictGpu(String endpointRegion, String endpointId) - throws IOException { - // TODO(developer): Update & uncomment line below - // String projectId = "your-project-id"; - String projectId = System.getenv("GOOGLE_CLOUD_PROJECT"); - - // Default configuration - Map config = new HashMap<>(); - config.put("max_tokens", 1024); - config.put("temperature", 0.9); - config.put("top_p", 1.0); - config.put("top_k", 1); - - // Prompt used in the prediction - String prompt = "Why is the sky blue?"; - - // Encapsulate the prompt in a correct format for GPUs - // Example format: [{'inputs': 'Why is the sky blue?', 'parameters': 
{'temperature': 0.9}}] - Map input = new HashMap<>(); - input.put("inputs", prompt); - input.put("parameters", config); - - // Convert input message to a list of GAPIC instances for model input - Value.Builder valueBuilder = Value.newBuilder(); - JsonFormat.parser().merge(JsonFormat.printer().print(input), valueBuilder); - Value instance = valueBuilder.build(); - - // Create a client - GoogleCredentials credentials = GoogleCredentials.getApplicationDefault(); - PredictionServiceSettings predictionServiceSettings = - PredictionServiceSettings.newBuilder() - .setCredentialsProvider(FixedCredentialsProvider.create(credentials)) - .setEndpoint(String.format("%s-aiplatform.googleapis.com:443", endpointRegion)) - .build(); - - try (PredictionServiceClient predictionServiceClient = - PredictionServiceClient.create(predictionServiceSettings)) { - // Call the Gemma2 endpoint - String endpointName = - EndpointName.of(projectId, endpointRegion, endpointId).toString(); - PredictRequest predictRequest = - PredictRequest.newBuilder() - .setEndpoint(endpointName) - .addAllInstances(Collections.singletonList(instance)) - .build(); - PredictResponse predictResponse = predictionServiceClient.predict(predictRequest); - String textResponse = predictResponse.getPredictions(0).getStringValue(); - System.out.println(textResponse); - return textResponse; - } - } -} \ No newline at end of file From 0c7608f1c677b2e126930ecb9d1a85c78cdb7090 Mon Sep 17 00:00:00 2001 From: Tetiana Yahodska Date: Fri, 20 Sep 2024 14:40:04 +0200 Subject: [PATCH 05/10] Added generativeaionvertexai_gemma2_predict_gpu and generativeaionvertexai_gemma2_predict_tpu samples, created test --- .../java/aiplatform/Gemma2PredictGpu.java | 51 ++++--- .../java/aiplatform/Gemma2PredictTpu.java | 47 +++--- .../java/aiplatform/Gemma2PredictTest.java | 143 +++++++----------- 3 files changed, 111 insertions(+), 130 deletions(-) diff --git a/aiplatform/src/main/java/aiplatform/Gemma2PredictGpu.java 
b/aiplatform/src/main/java/aiplatform/Gemma2PredictGpu.java index e3b2d3b2942..850e7ea36cb 100644 --- a/aiplatform/src/main/java/aiplatform/Gemma2PredictGpu.java +++ b/aiplatform/src/main/java/aiplatform/Gemma2PredictGpu.java @@ -19,7 +19,6 @@ // [START generativeaionvertexai_gemma2_predict_gpu] import com.google.cloud.aiplatform.v1.EndpointName; -import com.google.cloud.aiplatform.v1.PredictRequest; import com.google.cloud.aiplatform.v1.PredictResponse; import com.google.cloud.aiplatform.v1.PredictionServiceClient; import com.google.cloud.aiplatform.v1.PredictionServiceSettings; @@ -31,12 +30,18 @@ public class Gemma2PredictGpu { + private final PredictionServiceClient predictionServiceClient; + + // Constructor to inject the PredictionServiceClient + public Gemma2PredictGpu(PredictionServiceClient predictionServiceClient) { + this.predictionServiceClient = predictionServiceClient; + } + public static void main(String[] args) throws IOException { // TODO(developer): Update & uncomment line below - // String projectId = "your-project-id"; - String projectId = "rsamborski-ai-hypercomputer"; + String projectId = "YOUR_PROJECT_ID"; String region = "us-east4"; - String endpointId = "323876543124209664"; + String endpointId = "YOUR_ENDPOINT_ID"; String parameters = "{\n" + " \"temperature\": 0.3,\n" @@ -44,19 +49,24 @@ public static void main(String[] args) throws IOException { + " \"topP\": 0.8,\n" + " \"topK\": 40\n" + "}"; - - gemma2PredictGpu(projectId, region, endpointId, parameters); - } - - // Sample to run interference on a Gemma2 model deployed to a Vertex AI endpoint with GPU accellerators. 
- public static String gemma2PredictGpu(String projectId, String region, String endpointId, String parameters) - throws IOException { PredictionServiceSettings predictionServiceSettings = PredictionServiceSettings.newBuilder() .setEndpoint(String.format("%s-aiplatform.googleapis.com:443", region)) .build(); + + PredictionServiceClient predictionServiceClient = + PredictionServiceClient.create(predictionServiceSettings); + Gemma2PredictGpu creator = new Gemma2PredictGpu(predictionServiceClient); + + creator.gemma2PredictGpu(projectId, region, endpointId, parameters); + } + + // Demonstrates how to run interference on a Gemma2 model + // deployed to a Vertex AI endpoint with GPU accelerators. + public String gemma2PredictGpu(String projectId, String region, + String endpointId, String parameters) throws IOException { // Prompt used in the prediction - String instance = "{ \"inputs\": \"Why is the sky blue?\"}"; + String instance = "{ \"inputs\": \"Why is the sky blue?\"}"; Value.Builder parameterValueBuilder = Value.newBuilder(); JsonFormat.parser().merge(parameters, parameterValueBuilder); @@ -68,17 +78,14 @@ public static String gemma2PredictGpu(String projectId, String region, String en List instances = new ArrayList<>(); instances.add(instanceValue.build()); - try (PredictionServiceClient predictionServiceClient = - PredictionServiceClient.create(predictionServiceSettings)) { - // Call the Gemma2 endpoint - EndpointName endpointName = EndpointName.of(projectId, region, endpointId); + // Call the Gemma2 endpoint + EndpointName endpointName = EndpointName.of(projectId, region, endpointId); - PredictResponse predictResponse = predictionServiceClient - .predict(endpointName, instances, parameterValue); - String textResponse = predictResponse.getPredictions(0).getStringValue(); - System.out.println(textResponse); - return textResponse; - } + PredictResponse predictResponse = this.predictionServiceClient + .predict(endpointName, instances, parameterValue); + String 
textResponse = predictResponse.getPredictions(0).getStringValue(); + System.out.println(textResponse); + return textResponse; } } // [END generativeaionvertexai_gemma2_predict_gpu] \ No newline at end of file diff --git a/aiplatform/src/main/java/aiplatform/Gemma2PredictTpu.java b/aiplatform/src/main/java/aiplatform/Gemma2PredictTpu.java index 0ffd2d3f4c0..8f541ef8316 100644 --- a/aiplatform/src/main/java/aiplatform/Gemma2PredictTpu.java +++ b/aiplatform/src/main/java/aiplatform/Gemma2PredictTpu.java @@ -19,7 +19,6 @@ // [START generativeaionvertexai_gemma2_predict_tpu] import com.google.cloud.aiplatform.v1.EndpointName; -import com.google.cloud.aiplatform.v1.PredictRequest; import com.google.cloud.aiplatform.v1.PredictResponse; import com.google.cloud.aiplatform.v1.PredictionServiceClient; import com.google.cloud.aiplatform.v1.PredictionServiceSettings; @@ -30,13 +29,18 @@ import java.util.List; public class Gemma2PredictTpu { + private final PredictionServiceClient predictionServiceClient; + + // Constructor to inject the PredictionServiceClient + public Gemma2PredictTpu(PredictionServiceClient predictionServiceClient) { + this.predictionServiceClient = predictionServiceClient; + } public static void main(String[] args) throws IOException { // TODO(developer): Update & uncomment line below - // String projectId = "your-project-id"; - String projectId = "rsamborski-ai-hypercomputer"; + String projectId = "YOUR_PROJECT_ID"; String region = "us-west1"; - String endpointId = "9194824316951199744"; + String endpointId = "YOUR_ENDPOINT_ID"; String parameters = "{\n" + " \"temperature\": 0.3,\n" @@ -44,17 +48,21 @@ public static void main(String[] args) throws IOException { + " \"topP\": 0.8,\n" + " \"topK\": 40\n" + "}"; - - gemma2PredictTpu(projectId, region, endpointId, parameters); - } - - // Sample to run interference on a Gemma2 model deployed to a Vertex AI endpoint with GPU accellerators. 
- public static String gemma2PredictTpu(String projectId, String region, String endpointId, String parameters) - throws IOException { PredictionServiceSettings predictionServiceSettings = PredictionServiceSettings.newBuilder() .setEndpoint(String.format("%s-aiplatform.googleapis.com:443", region)) .build(); + + PredictionServiceClient predictionServiceClient = + PredictionServiceClient.create(predictionServiceSettings); + Gemma2PredictTpu creator = new Gemma2PredictTpu(predictionServiceClient); + creator.gemma2PredictTpu(projectId, region, endpointId, parameters); + } + + // Demonstrates how to run interference on a Gemma2 model + // deployed to a Vertex AI endpoint with TPU accelerators. + public String gemma2PredictTpu(String projectId, String region, + String endpointId, String parameters) throws IOException { // Prompt used in the prediction String instance = "{ \"prompt\": \"Why is the sky blue?\"}"; @@ -68,17 +76,14 @@ public static String gemma2PredictTpu(String projectId, String region, String en List instances = new ArrayList<>(); instances.add(instanceValue.build()); - try (PredictionServiceClient predictionServiceClient = - PredictionServiceClient.create(predictionServiceSettings)) { - // Call the Gemma2 endpoint - EndpointName endpointName = EndpointName.of(projectId, region, endpointId); + // Call the Gemma2 endpoint + EndpointName endpointName = EndpointName.of(projectId, region, endpointId); - PredictResponse predictResponse = predictionServiceClient - .predict(endpointName, instances, parameterValue); - String textResponse = predictResponse.getPredictions(0).getStringValue(); - System.out.println(textResponse); - return textResponse; - } + PredictResponse predictResponse = this.predictionServiceClient + .predict(endpointName, instances, parameterValue); + String textResponse = predictResponse.getPredictions(0).getStringValue(); + System.out.println(textResponse); + return textResponse; } } // [END generativeaionvertexai_gemma2_predict_tpu] diff 
--git a/aiplatform/src/test/java/aiplatform/Gemma2PredictTest.java b/aiplatform/src/test/java/aiplatform/Gemma2PredictTest.java index 7bc35ba71af..70616831e7e 100644 --- a/aiplatform/src/test/java/aiplatform/Gemma2PredictTest.java +++ b/aiplatform/src/test/java/aiplatform/Gemma2PredictTest.java @@ -1,108 +1,77 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package aiplatform; -import com.google.cloud.aiplatform.v1.PredictRequest; +import static org.junit.jupiter.api.Assertions.assertEquals; + +import com.google.cloud.aiplatform.v1.EndpointName; import com.google.cloud.aiplatform.v1.PredictResponse; import com.google.cloud.aiplatform.v1.PredictionServiceClient; import com.google.protobuf.Value; - import java.io.IOException; -import java.util.Map; -import org.junit.Assert; -import org.junit.Test; -import org.junit.runner.RunWith; +import java.util.List; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; import org.mockito.Mockito; -import org.mockito.junit.MockitoJUnitRunner; - -import static aiplatform.Gemma2PredictGpu.gemma2PredictGpu; -import static aiplatform.Gemma2PredictTpu.gemma2PredictTpu; -@RunWith(MockitoJUnitRunner.class) public class Gemma2PredictTest { + static String mockedResponse = "The sky appears blue due to a phenomenon " + + "called **Rayleigh scattering**.\n" + + "**Here's how it works:**\n" + + "* **Sunlight is white:** Sunlight actually contains all the colors of the 
rainbow.\n" + + "* **Scattering:** When sunlight enters the Earth's atmosphere, it collides with tiny gas" + + " molecules (mostly nitrogen and oxygen). These collisions cause the light to scatter " + + "in different directions.\n" + + "* **Blue light scatters most:** Blue light has a shorter wavelength"; + String projectId = "your-project-id"; + String region = "us-central1"; + String endpointId = "your-endpoint-id"; + String parameters = "{}"; + static PredictionServiceClient mockPredictionServiceClient; - // Global variables - private static final String PROJECT_ID = "rsamborski-ai-hypercomputer"; - private static final String GPU_ENDPOINT_REGION = "us-east1"; - private static final String GPU_ENDPOINT_ID = "123456789"; // Mock ID used to check if GPU was called - private static final String TPU_ENDPOINT_REGION = "us-west1"; - private static final String TPU_ENDPOINT_ID = "987654321"; // Mock ID used to check if TPU was called - private static final String PARAMETERS = - "{\n" - + " \"temperature\": 0.3,\n" - + " \"maxDecodeSteps\": 200,\n" - + " \"topP\": 0.8,\n" - + " \"topK\": 40\n" - + "}"; - - // MOCKED RESPONSE - private static final String MODEL_RESPONSES = - "The sky appears blue due to a phenomenon called **Rayleigh scattering**.\n" - + "**Here's how it works:**\n" - + "1. **Sunlight:** Sunlight is composed of all the colors of the rainbow.\n" - + "2. **Earth's Atmosphere:** When sunlight enters the Earth's atmosphere, it collides with tiny particles like nitrogen and oxygen molecules.\n" - + "3. **Scattering:** These particles scatter the sunlight in all directions. However, blue light (which has a shorter wavelength) is scattered more effectively than other colors.\n" - + "4. **Our Perception:** As a result, we see a blue sky because the scattered blue light reaches our eyes from all directions.\n" - + "**Why not other colors?**\n" - + "* **Violet light** has an even shorter wavelength than blue and is scattered even more. 
However, our eyes are less sensitive to violet light, so we perceive the sky as blue.\n" - + "* **Longer wavelengths** like red, orange, and yellow are scattered less and travel more directly through the atmosphere. This is why we see these colors during sunrise and sunset, when sunlight has to travel through more of the atmosphere.\n"; - - private PredictResponse mockPredict(String endpoint, Value instance) { - String gpuEndpoint = - String.format( - "projects/%s/locations/%s/endpoints/%s", - PROJECT_ID, GPU_ENDPOINT_REGION, GPU_ENDPOINT_ID); - String tpuEndpoint = - String.format( - "projects/%s/locations/%s/endpoints/%s", - PROJECT_ID, TPU_ENDPOINT_REGION, TPU_ENDPOINT_ID); - - Map instanceFields = - instance.getStructValue().getFieldsMap(); - - if (endpoint.equals(gpuEndpoint)) { - Assert.assertTrue(instanceFields.containsKey("inputs") && instanceFields.get("inputs").hasStringValue()); - } else if (endpoint.equals(tpuEndpoint)) { - Assert.assertTrue(instanceFields.containsKey("prompt") && instanceFields.get("prompt").hasStringValue()); - } else { - Assert.fail("Unexpected endpoint: " + endpoint); - } - - PredictResponse response = + @BeforeAll + public static void setUp() { + // Mock PredictionServiceClient and its response + mockPredictionServiceClient = Mockito.mock(PredictionServiceClient.class); + PredictResponse predictResponse = PredictResponse.newBuilder() - .addPredictions(Value.newBuilder().setStringValue(MODEL_RESPONSES).build()) + .addPredictions(Value.newBuilder().setStringValue(mockedResponse).build()) .build(); - return response; + Mockito.when( + mockPredictionServiceClient.predict( + Mockito.any(EndpointName.class), + Mockito.any(List.class), + Mockito.any(Value.class))) + .thenReturn(predictResponse); } @Test - public void testGemma2PredictGpu() throws IOException { - PredictionServiceClient mockClient = Mockito.mock(PredictionServiceClient.class); - Mockito.when(mockClient.predict(Mockito.any(PredictRequest.class))) - .thenAnswer( - invocation 
-> { - PredictRequest request = invocation.getArgument(0); - return mockPredict(request.getEndpoint(), request.getInstances(0)); - }); + public void testGemma2PredictTpu() throws IOException { + Gemma2PredictTpu creator = new Gemma2PredictTpu(mockPredictionServiceClient); + String response = creator.gemma2PredictTpu(projectId, region, endpointId, parameters); - String response = - gemma2PredictGpu( - PROJECT_ID, GPU_ENDPOINT_REGION, GPU_ENDPOINT_ID, PARAMETERS); - Assert.assertTrue(response.contains("Rayleigh scattering")); + assertEquals(mockedResponse, response); } @Test - public void testGemma2PredictTpu() throws IOException { - PredictionServiceClient mockClient = Mockito.mock(PredictionServiceClient.class); - Mockito.when(mockClient.predict(Mockito.any(PredictRequest.class))) - .thenAnswer( - invocation -> { - PredictRequest request = invocation.getArgument(0); - return mockPredict(request.getEndpoint(), request.getInstances(0)); - }); + public void testGemma2PredictGpu() throws IOException { + Gemma2PredictGpu creator = new Gemma2PredictGpu(mockPredictionServiceClient); + String response = creator.gemma2PredictGpu(projectId, region, endpointId, parameters); - String response = - gemma2PredictTpu( - PROJECT_ID, TPU_ENDPOINT_REGION, TPU_ENDPOINT_ID, PARAMETERS); - Assert.assertTrue(response.contains("Rayleigh scattering")); + assertEquals(mockedResponse, response); } - -} \ No newline at end of file +} From f3538b8c3cdb2cbe7c5bf65e4415126c6a88601b Mon Sep 17 00:00:00 2001 From: Tetiana Yahodska Date: Fri, 20 Sep 2024 14:51:38 +0200 Subject: [PATCH 06/10] Fixed comments --- aiplatform/src/main/java/aiplatform/Gemma2PredictGpu.java | 5 ++--- aiplatform/src/main/java/aiplatform/Gemma2PredictTpu.java | 6 +++--- aiplatform/src/test/java/aiplatform/Gemma2PredictTest.java | 3 +-- 3 files changed, 6 insertions(+), 8 deletions(-) diff --git a/aiplatform/src/main/java/aiplatform/Gemma2PredictGpu.java b/aiplatform/src/main/java/aiplatform/Gemma2PredictGpu.java index 
850e7ea36cb..77be75787ee 100644 --- a/aiplatform/src/main/java/aiplatform/Gemma2PredictGpu.java +++ b/aiplatform/src/main/java/aiplatform/Gemma2PredictGpu.java @@ -38,7 +38,7 @@ public Gemma2PredictGpu(PredictionServiceClient predictionServiceClient) { } public static void main(String[] args) throws IOException { - // TODO(developer): Update & uncomment line below + // TODO(developer): Replace these variables before running the sample. String projectId = "YOUR_PROJECT_ID"; String region = "us-east4"; String endpointId = "YOUR_ENDPOINT_ID"; @@ -49,11 +49,11 @@ public static void main(String[] args) throws IOException { + " \"topP\": 0.8,\n" + " \"topK\": 40\n" + "}"; + PredictionServiceSettings predictionServiceSettings = PredictionServiceSettings.newBuilder() .setEndpoint(String.format("%s-aiplatform.googleapis.com:443", region)) .build(); - PredictionServiceClient predictionServiceClient = PredictionServiceClient.create(predictionServiceSettings); Gemma2PredictGpu creator = new Gemma2PredictGpu(predictionServiceClient); @@ -78,7 +78,6 @@ public String gemma2PredictGpu(String projectId, String region, List instances = new ArrayList<>(); instances.add(instanceValue.build()); - // Call the Gemma2 endpoint EndpointName endpointName = EndpointName.of(projectId, region, endpointId); PredictResponse predictResponse = this.predictionServiceClient diff --git a/aiplatform/src/main/java/aiplatform/Gemma2PredictTpu.java b/aiplatform/src/main/java/aiplatform/Gemma2PredictTpu.java index 8f541ef8316..7404c743c6e 100644 --- a/aiplatform/src/main/java/aiplatform/Gemma2PredictTpu.java +++ b/aiplatform/src/main/java/aiplatform/Gemma2PredictTpu.java @@ -37,7 +37,7 @@ public Gemma2PredictTpu(PredictionServiceClient predictionServiceClient) { } public static void main(String[] args) throws IOException { - // TODO(developer): Update & uncomment line below + // TODO(developer): Replace these variables before running the sample. 
String projectId = "YOUR_PROJECT_ID"; String region = "us-west1"; String endpointId = "YOUR_ENDPOINT_ID"; @@ -48,14 +48,15 @@ public static void main(String[] args) throws IOException { + " \"topP\": 0.8,\n" + " \"topK\": 40\n" + "}"; + PredictionServiceSettings predictionServiceSettings = PredictionServiceSettings.newBuilder() .setEndpoint(String.format("%s-aiplatform.googleapis.com:443", region)) .build(); - PredictionServiceClient predictionServiceClient = PredictionServiceClient.create(predictionServiceSettings); Gemma2PredictTpu creator = new Gemma2PredictTpu(predictionServiceClient); + creator.gemma2PredictTpu(projectId, region, endpointId, parameters); } @@ -76,7 +77,6 @@ public String gemma2PredictTpu(String projectId, String region, List instances = new ArrayList<>(); instances.add(instanceValue.build()); - // Call the Gemma2 endpoint EndpointName endpointName = EndpointName.of(projectId, region, endpointId); PredictResponse predictResponse = this.predictionServiceClient diff --git a/aiplatform/src/test/java/aiplatform/Gemma2PredictTest.java b/aiplatform/src/test/java/aiplatform/Gemma2PredictTest.java index 70616831e7e..0cf5a4a8da5 100644 --- a/aiplatform/src/test/java/aiplatform/Gemma2PredictTest.java +++ b/aiplatform/src/test/java/aiplatform/Gemma2PredictTest.java @@ -51,8 +51,7 @@ public static void setUp() { PredictResponse.newBuilder() .addPredictions(Value.newBuilder().setStringValue(mockedResponse).build()) .build(); - Mockito.when( - mockPredictionServiceClient.predict( + Mockito.when(mockPredictionServiceClient.predict( Mockito.any(EndpointName.class), Mockito.any(List.class), Mockito.any(Value.class))) From fe54ef1186914e3ab890289f47fd4f9ad99a49cc Mon Sep 17 00:00:00 2001 From: Tetiana Yahodska Date: Wed, 25 Sep 2024 21:18:40 +0200 Subject: [PATCH 07/10] added comments, created new test --- .../java/aiplatform/Gemma2PredictGpu.java | 16 +-- .../java/aiplatform/Gemma2PredictTpu.java | 16 +-- .../aiplatform/Gemma2PredictionsTest.java | 109 
++++++++++++++++++ 3 files changed, 127 insertions(+), 14 deletions(-) create mode 100644 aiplatform/src/test/java/aiplatform/Gemma2PredictionsTest.java diff --git a/aiplatform/src/main/java/aiplatform/Gemma2PredictGpu.java b/aiplatform/src/main/java/aiplatform/Gemma2PredictGpu.java index 77be75787ee..31f6133ec02 100644 --- a/aiplatform/src/main/java/aiplatform/Gemma2PredictGpu.java +++ b/aiplatform/src/main/java/aiplatform/Gemma2PredictGpu.java @@ -40,25 +40,25 @@ public Gemma2PredictGpu(PredictionServiceClient predictionServiceClient) { public static void main(String[] args) throws IOException { // TODO(developer): Replace these variables before running the sample. String projectId = "YOUR_PROJECT_ID"; - String region = "us-east4"; + String endpointRegion = "us-east4"; String endpointId = "YOUR_ENDPOINT_ID"; String parameters = "{\n" - + " \"temperature\": 0.3,\n" - + " \"maxDecodeSteps\": 200,\n" - + " \"topP\": 0.8,\n" - + " \"topK\": 40\n" + + " \"temperature\": 0.9,\n" + + " \"maxOutputTokens\": 1024,\n" + + " \"topP\": 1.0,\n" + + " \"topK\": 1\n" + "}"; PredictionServiceSettings predictionServiceSettings = PredictionServiceSettings.newBuilder() - .setEndpoint(String.format("%s-aiplatform.googleapis.com:443", region)) + .setEndpoint(String.format("%s-aiplatform.googleapis.com:443", endpointRegion)) .build(); PredictionServiceClient predictionServiceClient = PredictionServiceClient.create(predictionServiceSettings); Gemma2PredictGpu creator = new Gemma2PredictGpu(predictionServiceClient); - creator.gemma2PredictGpu(projectId, region, endpointId, parameters); + creator.gemma2PredictGpu(projectId, endpointRegion, endpointId, parameters); } // Demonstrates how to run interference on a Gemma2 model @@ -75,6 +75,8 @@ public String gemma2PredictGpu(String projectId, String region, Value.Builder instanceValue = Value.newBuilder(); JsonFormat.parser().merge(instance, instanceValue); + // Encapsulate the prompt in a correct format for GPUs + // Example format: 
[{'prompt': 'Why is the sky blue?', 'parameters': {'temperature': 0.8}}] List instances = new ArrayList<>(); instances.add(instanceValue.build()); diff --git a/aiplatform/src/main/java/aiplatform/Gemma2PredictTpu.java b/aiplatform/src/main/java/aiplatform/Gemma2PredictTpu.java index 7404c743c6e..518ea8a6408 100644 --- a/aiplatform/src/main/java/aiplatform/Gemma2PredictTpu.java +++ b/aiplatform/src/main/java/aiplatform/Gemma2PredictTpu.java @@ -39,25 +39,25 @@ public Gemma2PredictTpu(PredictionServiceClient predictionServiceClient) { public static void main(String[] args) throws IOException { // TODO(developer): Replace these variables before running the sample. String projectId = "YOUR_PROJECT_ID"; - String region = "us-west1"; + String endpointRegion = "us-west1"; String endpointId = "YOUR_ENDPOINT_ID"; String parameters = "{\n" - + " \"temperature\": 0.3,\n" - + " \"maxDecodeSteps\": 200,\n" - + " \"topP\": 0.8,\n" - + " \"topK\": 40\n" + + " \"temperature\": 0.9,\n" + + " \"maxOutputTokens\": 1024,\n" + + " \"topP\": 1.0,\n" + + " \"topK\": 1\n" + "}"; PredictionServiceSettings predictionServiceSettings = PredictionServiceSettings.newBuilder() - .setEndpoint(String.format("%s-aiplatform.googleapis.com:443", region)) + .setEndpoint(String.format("%s-aiplatform.googleapis.com:443", endpointRegion)) .build(); PredictionServiceClient predictionServiceClient = PredictionServiceClient.create(predictionServiceSettings); Gemma2PredictTpu creator = new Gemma2PredictTpu(predictionServiceClient); - creator.gemma2PredictTpu(projectId, region, endpointId, parameters); + creator.gemma2PredictTpu(projectId, endpointRegion, endpointId, parameters); } // Demonstrates how to run interference on a Gemma2 model @@ -74,6 +74,8 @@ public String gemma2PredictTpu(String projectId, String region, Value.Builder instanceValue = Value.newBuilder(); JsonFormat.parser().merge(instance, instanceValue); + // Encapsulate the prompt in a correct format for TPUs + // Example format: [{'prompt': 
'Why is the sky blue?', 'parameters': {'temperature': 0.8}}] List instances = new ArrayList<>(); instances.add(instanceValue.build()); diff --git a/aiplatform/src/test/java/aiplatform/Gemma2PredictionsTest.java b/aiplatform/src/test/java/aiplatform/Gemma2PredictionsTest.java new file mode 100644 index 00000000000..d49fee6c07a --- /dev/null +++ b/aiplatform/src/test/java/aiplatform/Gemma2PredictionsTest.java @@ -0,0 +1,109 @@ +package aiplatform; + +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.*; + +import com.google.cloud.aiplatform.v1.PredictResponse; +import com.google.cloud.aiplatform.v1.PredictionServiceClient; +import com.google.cloud.aiplatform.v1.stub.PredictionServiceStub; +import com.google.protobuf.Value; +import com.google.protobuf.util.JsonFormat; +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + + +public class Gemma2PredictionsTest { + + private static final String RESPONSE = "The sky appears blue due to a phenomenon called **Rayleigh scattering**.\n" + + "**Here's how it works:**\n" + + "1. **Sunlight:** Sunlight is composed of all the colors of the rainbow.\n" + + "2. **Earth's Atmosphere:** When sunlight enters the Earth's atmosphere, it collides with tiny particles like nitrogen and oxygen molecules.\n" + + "3. **Scattering:** These particles scatter the sunlight in all directions. However, blue light (which has a shorter wavelength) is scattered more effectively than other colors.\n" + + "4. **Our Perception:** As a result, we see a blue sky because the scattered blue light reaches our eyes from all directions.\n" + + "**Why not other colors?**\n" + + "* **Violet light** has an even shorter wavelength than blue and is scattered even more. 
However, our eyes are less sensitive to violet light, so we perceive the sky as blue.\n" + + "* **Longer wavelengths** like red, orange, and yellow are scattered less and travel more directly through the atmosphere. This is why we see these colors during sunrise and sunset, when sunlight has to travel through more of the atmosphere.\n"; + + private static final String PROJECT_ID = "rsamborski-ai-hypercomputer"; + private static final String GPU_ENDPOINT_REGION = "us-east1"; + private static final String GPU_ENDPOINT_ID = "323876543124209664"; // Mock ID used to check if GPU was called + private static final String TPU_ENDPOINT_REGION = "us-west1"; + private static final String TPU_ENDPOINT_ID = "9194824316951199744"; + private static final String PARAMETERS = + "{\n" + + " \"temperature\": 0.3,\n" + + " \"maxOutputTokens\": 200,\n" + + " \"topP\": 0.8,\n" + + " \"topK\": 40\n" + + "}"; + private final PredictionServiceStub mockStub = Mockito.mock(PredictionServiceStub.class); + PredictionServiceClient client = PredictionServiceClient.create(mockStub); + + @Test + public void testShouldRunInterferenceWithGPU() throws IOException { + PredictionServiceStub mockStub = Mockito.mock(PredictionServiceStub.class); + PredictionServiceClient client = PredictionServiceClient.create(mockStub); + String instance = "{ \"inputs\": \"Why is the sky blue?\"}"; + Value.Builder instanceValue = Value.newBuilder(); + JsonFormat.parser().merge(instance, instanceValue); + + Value.Builder parameterValueBuilder = Value.newBuilder(); + JsonFormat.parser().merge(PARAMETERS, parameterValueBuilder); + Value parameterValue = parameterValueBuilder.build(); + + List instances = new ArrayList<>(); + instances.add(instanceValue.build()); + + PredictResponse mockResponse = PredictResponse.newBuilder() + .addPredictions(Value.newBuilder().setStringValue(RESPONSE).build()) + .build(); + + Mockito.when(client.predict(GPU_ENDPOINT_ID, instances, parameterValue)).thenReturn(mockResponse); + // 
NullPointerException + Gemma2PredictGpu gemma2PredictGpu = new Gemma2PredictGpu(client); + + String output = gemma2PredictGpu.gemma2PredictGpu(PROJECT_ID, GPU_ENDPOINT_REGION, GPU_ENDPOINT_ID, PARAMETERS); + + assertTrue(output.contains("Rayleigh scattering")); + verify(client, times(1)).predict(GPU_ENDPOINT_ID, instances, parameterValue); + assertTrue(instances.contains("inputs")); + assertTrue(parameterValue.getStructValue().containsFields("temperature")); + assertTrue(parameterValue.getStructValue().containsFields("maxOutputTokens")); + assertTrue(parameterValue.getStructValue().containsFields("topP")); + assertTrue(parameterValue.getStructValue().containsFields("topK")); + } + + @Test + public void testShouldRunInterferenceWithTPU() throws IOException { + + String instance = "{ \"prompt\": \"Why is the sky blue?\"}"; + Value.Builder instanceValue = Value.newBuilder(); + JsonFormat.parser().merge(instance, instanceValue); + + Value.Builder parameterValueBuilder = Value.newBuilder(); + JsonFormat.parser().merge(PARAMETERS, parameterValueBuilder); + Value parameterValue = parameterValueBuilder.build(); + + List instances = new ArrayList<>(); + instances.add(instanceValue.build()); + + PredictResponse mockResponse = PredictResponse.newBuilder() + .addPredictions(Value.newBuilder().setStringValue(RESPONSE).build()) + .build(); + when(client.predict(TPU_ENDPOINT_ID, instances, parameterValue)).thenReturn(mockResponse); + // NullPointerException + Gemma2PredictTpu gemma2PredictTpu = new Gemma2PredictTpu(client); + String output = gemma2PredictTpu.gemma2PredictTpu(PROJECT_ID,TPU_ENDPOINT_REGION, TPU_ENDPOINT_ID, PARAMETERS); + + assertTrue(output.contains("Rayleigh scattering")); + verify(client, times(1)).predict(GPU_ENDPOINT_ID, instances, parameterValue); + assertTrue(instances.contains("prompt")); + assertTrue(parameterValue.getStructValue().containsFields("temperature")); + assertTrue(parameterValue.getStructValue().containsFields("maxOutputTokens")); + 
assertTrue(parameterValue.getStructValue().containsFields("topP")); + assertTrue(parameterValue.getStructValue().containsFields("topK")); + } +} From 5ddc3610761b21ce0e8be8b585a900d5d3ce1567 Mon Sep 17 00:00:00 2001 From: Tetiana Yahodska Date: Thu, 26 Sep 2024 18:10:29 +0200 Subject: [PATCH 08/10] Fixed parameters, created test to check parameters --- .../java/aiplatform/Gemma2PredictGpu.java | 38 +++--- .../java/aiplatform/Gemma2PredictTpu.java | 37 +++--- .../java/aiplatform/Gemma2ParametersTest.java | 108 +++++++++++++++++ .../java/aiplatform/Gemma2PredictTest.java | 5 +- .../aiplatform/Gemma2PredictionsTest.java | 109 ------------------ 5 files changed, 153 insertions(+), 144 deletions(-) create mode 100644 aiplatform/src/test/java/aiplatform/Gemma2ParametersTest.java delete mode 100644 aiplatform/src/test/java/aiplatform/Gemma2PredictionsTest.java diff --git a/aiplatform/src/main/java/aiplatform/Gemma2PredictGpu.java b/aiplatform/src/main/java/aiplatform/Gemma2PredictGpu.java index 31f6133ec02..5893938f99d 100644 --- a/aiplatform/src/main/java/aiplatform/Gemma2PredictGpu.java +++ b/aiplatform/src/main/java/aiplatform/Gemma2PredictGpu.java @@ -22,11 +22,15 @@ import com.google.cloud.aiplatform.v1.PredictResponse; import com.google.cloud.aiplatform.v1.PredictionServiceClient; import com.google.cloud.aiplatform.v1.PredictionServiceSettings; +import com.google.gson.Gson; +import com.google.protobuf.InvalidProtocolBufferException; import com.google.protobuf.Value; import com.google.protobuf.util.JsonFormat; import java.io.IOException; import java.util.ArrayList; +import java.util.HashMap; import java.util.List; +import java.util.Map; public class Gemma2PredictGpu { @@ -42,13 +46,6 @@ public static void main(String[] args) throws IOException { String projectId = "YOUR_PROJECT_ID"; String endpointRegion = "us-east4"; String endpointId = "YOUR_ENDPOINT_ID"; - String parameters = - "{\n" - + " \"temperature\": 0.9,\n" - + " \"maxOutputTokens\": 1024,\n" - + " 
\"topP\": 1.0,\n" - + " \"topK\": 1\n" - + "}"; PredictionServiceSettings predictionServiceSettings = PredictionServiceSettings.newBuilder() @@ -58,23 +55,24 @@ public static void main(String[] args) throws IOException { PredictionServiceClient.create(predictionServiceSettings); Gemma2PredictGpu creator = new Gemma2PredictGpu(predictionServiceClient); - creator.gemma2PredictGpu(projectId, endpointRegion, endpointId, parameters); + creator.gemma2PredictGpu(projectId, endpointRegion, endpointId); } // Demonstrates how to run interference on a Gemma2 model // deployed to a Vertex AI endpoint with GPU accelerators. public String gemma2PredictGpu(String projectId, String region, - String endpointId, String parameters) throws IOException { + String endpointId) throws IOException { + Map paramsMap = new HashMap<>(); + paramsMap.put("temperature", 0.9); + paramsMap.put("maxOutputTokens", 1024); + paramsMap.put("topP", 1.0); + paramsMap.put("topK", 1); + Value parameters = mapToValue(paramsMap); + // Prompt used in the prediction String instance = "{ \"inputs\": \"Why is the sky blue?\"}"; - - Value.Builder parameterValueBuilder = Value.newBuilder(); - JsonFormat.parser().merge(parameters, parameterValueBuilder); - Value parameterValue = parameterValueBuilder.build(); - Value.Builder instanceValue = Value.newBuilder(); JsonFormat.parser().merge(instance, instanceValue); - // Encapsulate the prompt in a correct format for GPUs // Example format: [{'prompt': 'Why is the sky blue?', 'parameters': {'temperature': 0.8}}] List instances = new ArrayList<>(); @@ -83,10 +81,18 @@ public String gemma2PredictGpu(String projectId, String region, EndpointName endpointName = EndpointName.of(projectId, region, endpointId); PredictResponse predictResponse = this.predictionServiceClient - .predict(endpointName, instances, parameterValue); + .predict(endpointName, instances, parameters); String textResponse = predictResponse.getPredictions(0).getStringValue(); 
System.out.println(textResponse); return textResponse; } + + private static Value mapToValue(Map map) throws InvalidProtocolBufferException { + Gson gson = new Gson(); + String json = gson.toJson(map); + Value.Builder builder = Value.newBuilder(); + JsonFormat.parser().merge(json, builder); + return builder.build(); + } } // [END generativeaionvertexai_gemma2_predict_gpu] \ No newline at end of file diff --git a/aiplatform/src/main/java/aiplatform/Gemma2PredictTpu.java b/aiplatform/src/main/java/aiplatform/Gemma2PredictTpu.java index 518ea8a6408..ab694da6fec 100644 --- a/aiplatform/src/main/java/aiplatform/Gemma2PredictTpu.java +++ b/aiplatform/src/main/java/aiplatform/Gemma2PredictTpu.java @@ -22,11 +22,15 @@ import com.google.cloud.aiplatform.v1.PredictResponse; import com.google.cloud.aiplatform.v1.PredictionServiceClient; import com.google.cloud.aiplatform.v1.PredictionServiceSettings; +import com.google.gson.Gson; +import com.google.protobuf.InvalidProtocolBufferException; import com.google.protobuf.Value; import com.google.protobuf.util.JsonFormat; import java.io.IOException; import java.util.ArrayList; +import java.util.HashMap; import java.util.List; +import java.util.Map; public class Gemma2PredictTpu { private final PredictionServiceClient predictionServiceClient; @@ -41,13 +45,6 @@ public static void main(String[] args) throws IOException { String projectId = "YOUR_PROJECT_ID"; String endpointRegion = "us-west1"; String endpointId = "YOUR_ENDPOINT_ID"; - String parameters = - "{\n" - + " \"temperature\": 0.9,\n" - + " \"maxOutputTokens\": 1024,\n" - + " \"topP\": 1.0,\n" - + " \"topK\": 1\n" - + "}"; PredictionServiceSettings predictionServiceSettings = PredictionServiceSettings.newBuilder() @@ -57,23 +54,23 @@ public static void main(String[] args) throws IOException { PredictionServiceClient.create(predictionServiceSettings); Gemma2PredictTpu creator = new Gemma2PredictTpu(predictionServiceClient); - creator.gemma2PredictTpu(projectId, endpointRegion, 
endpointId, parameters); + creator.gemma2PredictTpu(projectId, endpointRegion, endpointId); } // Demonstrates how to run interference on a Gemma2 model // deployed to a Vertex AI endpoint with TPU accelerators. public String gemma2PredictTpu(String projectId, String region, - String endpointId, String parameters) throws IOException { + String endpointId) throws IOException { + Map paramsMap = new HashMap<>(); + paramsMap.put("temperature", 0.9); + paramsMap.put("maxOutputTokens", 1024); + paramsMap.put("topP", 1.0); + paramsMap.put("topK", 1); + Value parameters = mapToValue(paramsMap); // Prompt used in the prediction String instance = "{ \"prompt\": \"Why is the sky blue?\"}"; - - Value.Builder parameterValueBuilder = Value.newBuilder(); - JsonFormat.parser().merge(parameters, parameterValueBuilder); - Value parameterValue = parameterValueBuilder.build(); - Value.Builder instanceValue = Value.newBuilder(); JsonFormat.parser().merge(instance, instanceValue); - // Encapsulate the prompt in a correct format for TPUs // Example format: [{'prompt': 'Why is the sky blue?', 'parameters': {'temperature': 0.8}}] List instances = new ArrayList<>(); @@ -82,11 +79,19 @@ public String gemma2PredictTpu(String projectId, String region, EndpointName endpointName = EndpointName.of(projectId, region, endpointId); PredictResponse predictResponse = this.predictionServiceClient - .predict(endpointName, instances, parameterValue); + .predict(endpointName, instances, parameters); String textResponse = predictResponse.getPredictions(0).getStringValue(); System.out.println(textResponse); return textResponse; } + + private static Value mapToValue(Map map) throws InvalidProtocolBufferException { + Gson gson = new Gson(); + String json = gson.toJson(map); + Value.Builder builder = Value.newBuilder(); + JsonFormat.parser().merge(json, builder); + return builder.build(); + } } // [END generativeaionvertexai_gemma2_predict_tpu] diff --git 
a/aiplatform/src/test/java/aiplatform/Gemma2ParametersTest.java b/aiplatform/src/test/java/aiplatform/Gemma2ParametersTest.java new file mode 100644 index 00000000000..300eee49c93 --- /dev/null +++ b/aiplatform/src/test/java/aiplatform/Gemma2ParametersTest.java @@ -0,0 +1,108 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static org.junit.jupiter.api.Assertions.assertTrue; + +import com.google.cloud.aiplatform.v1.EndpointName; +import com.google.cloud.aiplatform.v1.PredictionServiceClient; +import com.google.gson.Gson; +import com.google.protobuf.InvalidProtocolBufferException; +import com.google.protobuf.Value; +import com.google.protobuf.util.JsonFormat; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; +import org.mockito.stubbing.Answer; + +public class Gemma2ParametersTest { + + static PredictionServiceClient mockGpuPredictionServiceClient; + static PredictionServiceClient mockTpuPredictionServiceClient; + private static final String INSTANCE_GPU = "{ \"inputs\": \"Why is the sky blue?\"}"; + private static final String INSTANCE_TPU = "{ \"prompt\": \"Why is the sky blue?\"}"; + + @Test + public void parametersTest() throws InvalidProtocolBufferException { + // Mock GPU and TPU PredictionServiceClient and its response + mockGpuPredictionServiceClient = 
Mockito.mock(PredictionServiceClient.class); + mockTpuPredictionServiceClient = Mockito.mock(PredictionServiceClient.class); + + Value.Builder instanceValueGpu = Value.newBuilder(); + JsonFormat.parser().merge(INSTANCE_GPU, instanceValueGpu); + List instancesGpu = new ArrayList<>(); + instancesGpu.add(instanceValueGpu.build()); + + Value.Builder instanceValueTpu = Value.newBuilder(); + JsonFormat.parser().merge(INSTANCE_TPU, instanceValueTpu); + List instancesTpu = new ArrayList<>(); + instancesTpu.add(instanceValueTpu.build()); + + Map paramsMap = new HashMap<>(); + paramsMap.put("temperature", 0.9); + paramsMap.put("maxOutputTokens", 1024); + paramsMap.put("topP", 1.0); + paramsMap.put("topK", 1); + Value parameters = mapToValue(paramsMap); + + Mockito.when(mockGpuPredictionServiceClient.predict( + Mockito.any(EndpointName.class), + Mockito.any(List.class), + Mockito.any(Value.class))) + .thenAnswer(invocation -> + mockGpuResponse(instancesGpu, parameters)); + + Mockito.when(mockTpuPredictionServiceClient.predict( + Mockito.any(EndpointName.class), + Mockito.any(List.class), + Mockito.any(Value.class))) + .thenAnswer(invocation -> + mockTpuResponse(instancesTpu, parameters)); + } + + public static Answer mockGpuResponse(List instances, Value parameter) { + + assertTrue(instances.get(0).getStructValue().getFieldsMap().containsKey("inputs")); + assertTrue(parameter.getStructValue().containsFields("temperature")); + assertTrue(parameter.getStructValue().containsFields("maxOutputTokens")); + assertTrue(parameter.getStructValue().containsFields("topP")); + assertTrue(parameter.getStructValue().containsFields("topK")); + return null; + } + + public static Answer mockTpuResponse(List instances, Value parameter) { + + assertTrue(instances.get(0).getStructValue().getFieldsMap().containsKey("prompt")); + assertTrue(parameter.getStructValue().containsFields("temperature")); + assertTrue(parameter.getStructValue().containsFields("maxOutputTokens")); + 
assertTrue(parameter.getStructValue().containsFields("topP")); + assertTrue(parameter.getStructValue().containsFields("topK")); + return null; + } + + private static Value mapToValue(Map map) throws InvalidProtocolBufferException { + Gson gson = new Gson(); + String json = gson.toJson(map); + Value.Builder builder = Value.newBuilder(); + JsonFormat.parser().merge(json, builder); + return builder.build(); + } +} + diff --git a/aiplatform/src/test/java/aiplatform/Gemma2PredictTest.java b/aiplatform/src/test/java/aiplatform/Gemma2PredictTest.java index 0cf5a4a8da5..9a78695a2a4 100644 --- a/aiplatform/src/test/java/aiplatform/Gemma2PredictTest.java +++ b/aiplatform/src/test/java/aiplatform/Gemma2PredictTest.java @@ -40,7 +40,6 @@ public class Gemma2PredictTest { String projectId = "your-project-id"; String region = "us-central1"; String endpointId = "your-endpoint-id"; - String parameters = "{}"; static PredictionServiceClient mockPredictionServiceClient; @BeforeAll @@ -61,7 +60,7 @@ public static void setUp() { @Test public void testGemma2PredictTpu() throws IOException { Gemma2PredictTpu creator = new Gemma2PredictTpu(mockPredictionServiceClient); - String response = creator.gemma2PredictTpu(projectId, region, endpointId, parameters); + String response = creator.gemma2PredictTpu(projectId, region, endpointId); assertEquals(mockedResponse, response); } @@ -69,7 +68,7 @@ public void testGemma2PredictTpu() throws IOException { @Test public void testGemma2PredictGpu() throws IOException { Gemma2PredictGpu creator = new Gemma2PredictGpu(mockPredictionServiceClient); - String response = creator.gemma2PredictGpu(projectId, region, endpointId, parameters); + String response = creator.gemma2PredictGpu(projectId, region, endpointId); assertEquals(mockedResponse, response); } diff --git a/aiplatform/src/test/java/aiplatform/Gemma2PredictionsTest.java b/aiplatform/src/test/java/aiplatform/Gemma2PredictionsTest.java deleted file mode 100644 index d49fee6c07a..00000000000 --- 
a/aiplatform/src/test/java/aiplatform/Gemma2PredictionsTest.java +++ /dev/null @@ -1,109 +0,0 @@ -package aiplatform; - -import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.mockito.Mockito.*; - -import com.google.cloud.aiplatform.v1.PredictResponse; -import com.google.cloud.aiplatform.v1.PredictionServiceClient; -import com.google.cloud.aiplatform.v1.stub.PredictionServiceStub; -import com.google.protobuf.Value; -import com.google.protobuf.util.JsonFormat; -import org.junit.jupiter.api.Test; -import org.mockito.Mockito; -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; - - -public class Gemma2PredictionsTest { - - private static final String RESPONSE = "The sky appears blue due to a phenomenon called **Rayleigh scattering**.\n" + - "**Here's how it works:**\n" + - "1. **Sunlight:** Sunlight is composed of all the colors of the rainbow.\n" + - "2. **Earth's Atmosphere:** When sunlight enters the Earth's atmosphere, it collides with tiny particles like nitrogen and oxygen molecules.\n" + - "3. **Scattering:** These particles scatter the sunlight in all directions. However, blue light (which has a shorter wavelength) is scattered more effectively than other colors.\n" + - "4. **Our Perception:** As a result, we see a blue sky because the scattered blue light reaches our eyes from all directions.\n" + - "**Why not other colors?**\n" + - "* **Violet light** has an even shorter wavelength than blue and is scattered even more. However, our eyes are less sensitive to violet light, so we perceive the sky as blue.\n" + - "* **Longer wavelengths** like red, orange, and yellow are scattered less and travel more directly through the atmosphere. 
This is why we see these colors during sunrise and sunset, when sunlight has to travel through more of the atmosphere.\n"; - - private static final String PROJECT_ID = "rsamborski-ai-hypercomputer"; - private static final String GPU_ENDPOINT_REGION = "us-east1"; - private static final String GPU_ENDPOINT_ID = "323876543124209664"; // Mock ID used to check if GPU was called - private static final String TPU_ENDPOINT_REGION = "us-west1"; - private static final String TPU_ENDPOINT_ID = "9194824316951199744"; - private static final String PARAMETERS = - "{\n" - + " \"temperature\": 0.3,\n" - + " \"maxOutputTokens\": 200,\n" - + " \"topP\": 0.8,\n" - + " \"topK\": 40\n" - + "}"; - private final PredictionServiceStub mockStub = Mockito.mock(PredictionServiceStub.class); - PredictionServiceClient client = PredictionServiceClient.create(mockStub); - - @Test - public void testShouldRunInterferenceWithGPU() throws IOException { - PredictionServiceStub mockStub = Mockito.mock(PredictionServiceStub.class); - PredictionServiceClient client = PredictionServiceClient.create(mockStub); - String instance = "{ \"inputs\": \"Why is the sky blue?\"}"; - Value.Builder instanceValue = Value.newBuilder(); - JsonFormat.parser().merge(instance, instanceValue); - - Value.Builder parameterValueBuilder = Value.newBuilder(); - JsonFormat.parser().merge(PARAMETERS, parameterValueBuilder); - Value parameterValue = parameterValueBuilder.build(); - - List instances = new ArrayList<>(); - instances.add(instanceValue.build()); - - PredictResponse mockResponse = PredictResponse.newBuilder() - .addPredictions(Value.newBuilder().setStringValue(RESPONSE).build()) - .build(); - - Mockito.when(client.predict(GPU_ENDPOINT_ID, instances, parameterValue)).thenReturn(mockResponse); - // NullPointerException - Gemma2PredictGpu gemma2PredictGpu = new Gemma2PredictGpu(client); - - String output = gemma2PredictGpu.gemma2PredictGpu(PROJECT_ID, GPU_ENDPOINT_REGION, GPU_ENDPOINT_ID, PARAMETERS); - - 
assertTrue(output.contains("Rayleigh scattering")); - verify(client, times(1)).predict(GPU_ENDPOINT_ID, instances, parameterValue); - assertTrue(instances.contains("inputs")); - assertTrue(parameterValue.getStructValue().containsFields("temperature")); - assertTrue(parameterValue.getStructValue().containsFields("maxOutputTokens")); - assertTrue(parameterValue.getStructValue().containsFields("topP")); - assertTrue(parameterValue.getStructValue().containsFields("topK")); - } - - @Test - public void testShouldRunInterferenceWithTPU() throws IOException { - - String instance = "{ \"prompt\": \"Why is the sky blue?\"}"; - Value.Builder instanceValue = Value.newBuilder(); - JsonFormat.parser().merge(instance, instanceValue); - - Value.Builder parameterValueBuilder = Value.newBuilder(); - JsonFormat.parser().merge(PARAMETERS, parameterValueBuilder); - Value parameterValue = parameterValueBuilder.build(); - - List instances = new ArrayList<>(); - instances.add(instanceValue.build()); - - PredictResponse mockResponse = PredictResponse.newBuilder() - .addPredictions(Value.newBuilder().setStringValue(RESPONSE).build()) - .build(); - when(client.predict(TPU_ENDPOINT_ID, instances, parameterValue)).thenReturn(mockResponse); - // NullPointerException - Gemma2PredictTpu gemma2PredictTpu = new Gemma2PredictTpu(client); - String output = gemma2PredictTpu.gemma2PredictTpu(PROJECT_ID,TPU_ENDPOINT_REGION, TPU_ENDPOINT_ID, PARAMETERS); - - assertTrue(output.contains("Rayleigh scattering")); - verify(client, times(1)).predict(GPU_ENDPOINT_ID, instances, parameterValue); - assertTrue(instances.contains("prompt")); - assertTrue(parameterValue.getStructValue().containsFields("temperature")); - assertTrue(parameterValue.getStructValue().containsFields("maxOutputTokens")); - assertTrue(parameterValue.getStructValue().containsFields("topP")); - assertTrue(parameterValue.getStructValue().containsFields("topK")); - } -} From a2e3f43baf83f76f0d7c0db69ea07464e78ba092 Mon Sep 17 00:00:00 2001 
From: Tetiana Yahodska Date: Thu, 26 Sep 2024 18:19:01 +0200 Subject: [PATCH 09/10] Fixed comments --- aiplatform/src/main/java/aiplatform/Gemma2PredictGpu.java | 2 +- aiplatform/src/main/java/aiplatform/Gemma2PredictTpu.java | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/aiplatform/src/main/java/aiplatform/Gemma2PredictGpu.java b/aiplatform/src/main/java/aiplatform/Gemma2PredictGpu.java index 5893938f99d..cae28f59588 100644 --- a/aiplatform/src/main/java/aiplatform/Gemma2PredictGpu.java +++ b/aiplatform/src/main/java/aiplatform/Gemma2PredictGpu.java @@ -74,7 +74,7 @@ public String gemma2PredictGpu(String projectId, String region, Value.Builder instanceValue = Value.newBuilder(); JsonFormat.parser().merge(instance, instanceValue); // Encapsulate the prompt in a correct format for GPUs - // Example format: [{'prompt': 'Why is the sky blue?', 'parameters': {'temperature': 0.8}}] + // Example format: [{'inputs': 'Why is the sky blue?', 'parameters': {'temperature': 0.8}}] List instances = new ArrayList<>(); instances.add(instanceValue.build()); diff --git a/aiplatform/src/main/java/aiplatform/Gemma2PredictTpu.java b/aiplatform/src/main/java/aiplatform/Gemma2PredictTpu.java index ab694da6fec..a631104909b 100644 --- a/aiplatform/src/main/java/aiplatform/Gemma2PredictTpu.java +++ b/aiplatform/src/main/java/aiplatform/Gemma2PredictTpu.java @@ -72,7 +72,7 @@ public String gemma2PredictTpu(String projectId, String region, Value.Builder instanceValue = Value.newBuilder(); JsonFormat.parser().merge(instance, instanceValue); // Encapsulate the prompt in a correct format for TPUs - // Example format: [{'prompt': 'Why is the sky blue?', 'parameters': {'temperature': 0.8}}] + // Example format: [{'prompt': 'Why is the sky blue?', 'temperature': 0.9}] List instances = new ArrayList<>(); instances.add(instanceValue.build()); From 60f1ed889e7e4ae28e81c5b6e3f67c88e7eec8e6 Mon Sep 17 00:00:00 2001 From: Tetiana Yahodska Date: Thu, 26 Sep 2024 22:57:47 +0200 
Subject: [PATCH 10/10] Fixed comments --- aiplatform/src/main/java/aiplatform/Gemma2PredictGpu.java | 2 +- aiplatform/src/main/java/aiplatform/Gemma2PredictTpu.java | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/aiplatform/src/main/java/aiplatform/Gemma2PredictGpu.java b/aiplatform/src/main/java/aiplatform/Gemma2PredictGpu.java index cae28f59588..2c3b6c7dace 100644 --- a/aiplatform/src/main/java/aiplatform/Gemma2PredictGpu.java +++ b/aiplatform/src/main/java/aiplatform/Gemma2PredictGpu.java @@ -58,7 +58,7 @@ public static void main(String[] args) throws IOException { creator.gemma2PredictGpu(projectId, endpointRegion, endpointId); } - // Demonstrates how to run interference on a Gemma2 model + // Demonstrates how to run inference on a Gemma2 model // deployed to a Vertex AI endpoint with GPU accelerators. public String gemma2PredictGpu(String projectId, String region, String endpointId) throws IOException { diff --git a/aiplatform/src/main/java/aiplatform/Gemma2PredictTpu.java b/aiplatform/src/main/java/aiplatform/Gemma2PredictTpu.java index a631104909b..de29b1cc111 100644 --- a/aiplatform/src/main/java/aiplatform/Gemma2PredictTpu.java +++ b/aiplatform/src/main/java/aiplatform/Gemma2PredictTpu.java @@ -57,7 +57,7 @@ public static void main(String[] args) throws IOException { creator.gemma2PredictTpu(projectId, endpointRegion, endpointId); } - // Demonstrates how to run interference on a Gemma2 model + // Demonstrates how to run inference on a Gemma2 model // deployed to a Vertex AI endpoint with TPU accelerators. public String gemma2PredictTpu(String projectId, String region, String endpointId) throws IOException {