Skip to content
Merged
6 changes: 6 additions & 0 deletions aiplatform/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,12 @@
<version>1.7.1</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-core</artifactId>
<version>5.13.0</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter</artifactId>
Expand Down
98 changes: 98 additions & 0 deletions aiplatform/src/main/java/aiplatform/Gemma2PredictGpu.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
/*
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package aiplatform;

// [START generativeaionvertexai_gemma2_predict_gpu]

import com.google.cloud.aiplatform.v1.EndpointName;
import com.google.cloud.aiplatform.v1.PredictResponse;
import com.google.cloud.aiplatform.v1.PredictionServiceClient;
import com.google.cloud.aiplatform.v1.PredictionServiceSettings;
import com.google.gson.Gson;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.Value;
import com.google.protobuf.util.JsonFormat;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * Sample that sends a text prompt to a Gemma2 model deployed to a Vertex AI endpoint with GPU
 * accelerators and prints the returned prediction.
 *
 * <p>The {@link PredictionServiceClient} is injected through the constructor so that the sample
 * can be exercised with a mocked client. This class never closes the injected client; whoever
 * creates the client is responsible for closing it.
 */
public class Gemma2PredictGpu {

  private final PredictionServiceClient predictionServiceClient;

  /**
   * Creates the sample around an existing client.
   *
   * @param predictionServiceClient client used to send the prediction request
   */
  public Gemma2PredictGpu(PredictionServiceClient predictionServiceClient) {
    this.predictionServiceClient = predictionServiceClient;
  }

  public static void main(String[] args) throws IOException {
    // TODO(developer): Replace these variables before running the sample.
    String projectId = "YOUR_PROJECT_ID";
    String endpointRegion = "us-east4";
    String endpointId = "YOUR_ENDPOINT_ID";

    PredictionServiceSettings predictionServiceSettings =
        PredictionServiceSettings.newBuilder()
            .setEndpoint(String.format("%s-aiplatform.googleapis.com:443", endpointRegion))
            .build();

    // The client is created here, so it is closed here: try-with-resources guarantees
    // close() runs even when the prediction call throws.
    try (PredictionServiceClient predictionServiceClient =
        PredictionServiceClient.create(predictionServiceSettings)) {
      Gemma2PredictGpu creator = new Gemma2PredictGpu(predictionServiceClient);

      creator.gemma2PredictGpu(projectId, endpointRegion, endpointId);
    }
  }

  /**
   * Demonstrates how to run inference on a Gemma2 model deployed to a Vertex AI endpoint with GPU
   * accelerators.
   *
   * @param projectId project that owns the endpoint
   * @param region region the endpoint is deployed in
   * @param endpointId id of the endpoint serving the model
   * @return the string value of the first prediction in the response (also printed to stdout)
   * @throws IOException if the JSON for the instance or the parameters cannot be parsed into a
   *     protobuf {@link Value}
   */
  public String gemma2PredictGpu(String projectId, String region,
      String endpointId) throws IOException {
    Map<String, Object> paramsMap = new HashMap<>();
    paramsMap.put("temperature", 0.9);
    paramsMap.put("maxOutputTokens", 1024);
    paramsMap.put("topP", 1.0);
    paramsMap.put("topK", 1);
    Value parameters = mapToValue(paramsMap);

    // Prompt used in the prediction
    String instance = "{ \"inputs\": \"Why is the sky blue?\"}";
    Value.Builder instanceValue = Value.newBuilder();
    JsonFormat.parser().merge(instance, instanceValue);
    // Encapsulate the prompt in a correct format for GPUs
    // Example format: [{'inputs': 'Why is the sky blue?', 'parameters': {'temperature': 0.8}}]
    List<Value> instances = new ArrayList<>();
    instances.add(instanceValue.build());

    EndpointName endpointName = EndpointName.of(projectId, region, endpointId);

    PredictResponse predictResponse = this.predictionServiceClient
        .predict(endpointName, instances, parameters);
    String textResponse = predictResponse.getPredictions(0).getStringValue();
    System.out.println(textResponse);
    return textResponse;
  }

  /**
   * Converts a map to a protobuf {@link Value} by serializing it to JSON with Gson and parsing
   * that JSON with {@link JsonFormat}.
   */
  private static Value mapToValue(Map<String, Object> map) throws InvalidProtocolBufferException {
    Gson gson = new Gson();
    String json = gson.toJson(map);
    Value.Builder builder = Value.newBuilder();
    JsonFormat.parser().merge(json, builder);
    return builder.build();
  }
}
// [END generativeaionvertexai_gemma2_predict_gpu]
97 changes: 97 additions & 0 deletions aiplatform/src/main/java/aiplatform/Gemma2PredictTpu.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
/*
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package aiplatform;

// [START generativeaionvertexai_gemma2_predict_tpu]

import com.google.cloud.aiplatform.v1.EndpointName;
import com.google.cloud.aiplatform.v1.PredictResponse;
import com.google.cloud.aiplatform.v1.PredictionServiceClient;
import com.google.cloud.aiplatform.v1.PredictionServiceSettings;
import com.google.gson.Gson;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.Value;
import com.google.protobuf.util.JsonFormat;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * Sample that sends a text prompt to a Gemma2 model deployed to a Vertex AI endpoint with TPU
 * accelerators and prints the returned prediction.
 *
 * <p>The {@link PredictionServiceClient} is injected through the constructor so that the sample
 * can be exercised with a mocked client. This class never closes the injected client; whoever
 * creates the client is responsible for closing it.
 */
public class Gemma2PredictTpu {

  private final PredictionServiceClient predictionServiceClient;

  /**
   * Creates the sample around an existing client.
   *
   * @param predictionServiceClient client used to send the prediction request
   */
  public Gemma2PredictTpu(PredictionServiceClient predictionServiceClient) {
    this.predictionServiceClient = predictionServiceClient;
  }

  public static void main(String[] args) throws IOException {
    // TODO(developer): Replace these variables before running the sample.
    String projectId = "YOUR_PROJECT_ID";
    String endpointRegion = "us-west1";
    String endpointId = "YOUR_ENDPOINT_ID";

    PredictionServiceSettings predictionServiceSettings =
        PredictionServiceSettings.newBuilder()
            .setEndpoint(String.format("%s-aiplatform.googleapis.com:443", endpointRegion))
            .build();

    // The client is created here, so it is closed here: try-with-resources guarantees
    // close() runs even when the prediction call throws.
    try (PredictionServiceClient predictionServiceClient =
        PredictionServiceClient.create(predictionServiceSettings)) {
      Gemma2PredictTpu creator = new Gemma2PredictTpu(predictionServiceClient);

      creator.gemma2PredictTpu(projectId, endpointRegion, endpointId);
    }
  }

  /**
   * Demonstrates how to run inference on a Gemma2 model deployed to a Vertex AI endpoint with TPU
   * accelerators.
   *
   * @param projectId project that owns the endpoint
   * @param region region the endpoint is deployed in
   * @param endpointId id of the endpoint serving the model
   * @return the string value of the first prediction in the response (also printed to stdout)
   * @throws IOException if the JSON for the instance or the parameters cannot be parsed into a
   *     protobuf {@link Value}
   */
  public String gemma2PredictTpu(String projectId, String region,
      String endpointId) throws IOException {
    Map<String, Object> paramsMap = new HashMap<>();
    paramsMap.put("temperature", 0.9);
    paramsMap.put("maxOutputTokens", 1024);
    paramsMap.put("topP", 1.0);
    paramsMap.put("topK", 1);
    Value parameters = mapToValue(paramsMap);
    // Prompt used in the prediction
    String instance = "{ \"prompt\": \"Why is the sky blue?\"}";
    Value.Builder instanceValue = Value.newBuilder();
    JsonFormat.parser().merge(instance, instanceValue);
    // Encapsulate the prompt in a correct format for TPUs
    // Example format: [{'prompt': 'Why is the sky blue?', 'temperature': 0.9}]
    List<Value> instances = new ArrayList<>();
    instances.add(instanceValue.build());

    EndpointName endpointName = EndpointName.of(projectId, region, endpointId);

    PredictResponse predictResponse = this.predictionServiceClient
        .predict(endpointName, instances, parameters);
    String textResponse = predictResponse.getPredictions(0).getStringValue();
    System.out.println(textResponse);
    return textResponse;
  }

  /**
   * Converts a map to a protobuf {@link Value} by serializing it to JSON with Gson and parsing
   * that JSON with {@link JsonFormat}.
   */
  private static Value mapToValue(Map<String, Object> map) throws InvalidProtocolBufferException {
    Gson gson = new Gson();
    String json = gson.toJson(map);
    Value.Builder builder = Value.newBuilder();
    JsonFormat.parser().merge(json, builder);
    return builder.build();
  }
}
// [END generativeaionvertexai_gemma2_predict_tpu]

108 changes: 108 additions & 0 deletions aiplatform/src/test/java/aiplatform/Gemma2ParametersTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
/*
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package aiplatform;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;

import com.google.cloud.aiplatform.v1.EndpointName;
import com.google.cloud.aiplatform.v1.PredictResponse;
import com.google.cloud.aiplatform.v1.PredictionServiceClient;
import com.google.gson.Gson;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.Value;
import com.google.protobuf.util.JsonFormat;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.junit.jupiter.api.Test;
import org.mockito.Mockito;
import org.mockito.stubbing.Answer;

/**
 * Verifies that the Gemma2 GPU and TPU samples send the expected instance payload and sampling
 * parameters to {@link PredictionServiceClient#predict}.
 *
 * <p>Each sample is run against a mocked client. The stubbed answer inspects the arguments of the
 * actual {@code predict} invocation, so the assertions fail if a sample changes what it sends.
 */
public class Gemma2ParametersTest {

  static PredictionServiceClient mockGpuPredictionServiceClient;
  static PredictionServiceClient mockTpuPredictionServiceClient;
  private static final String INSTANCE_GPU = "{ \"inputs\": \"Why is the sky blue?\"}";
  private static final String INSTANCE_TPU = "{ \"prompt\": \"Why is the sky blue?\"}";
  private static final String PROJECT_ID = "your-project-id";
  private static final String REGION = "us-central1";
  private static final String ENDPOINT_ID = "your-endpoint-id";

  @Test
  public void parametersTest() throws IOException {
    // Mock GPU and TPU PredictionServiceClient and its response
    mockGpuPredictionServiceClient = Mockito.mock(PredictionServiceClient.class);
    mockTpuPredictionServiceClient = Mockito.mock(PredictionServiceClient.class);

    List<Value> expectedInstancesGpu = toInstances(INSTANCE_GPU);
    List<Value> expectedInstancesTpu = toInstances(INSTANCE_TPU);

    Map<String, Object> paramsMap = new HashMap<>();
    paramsMap.put("temperature", 0.9);
    paramsMap.put("maxOutputTokens", 1024);
    paramsMap.put("topP", 1.0);
    paramsMap.put("topK", 1);
    Value expectedParameters = mapToValue(paramsMap);

    // The samples read the first prediction of the response, so the stub must return one.
    PredictResponse cannedResponse =
        PredictResponse.newBuilder()
            .addPredictions(Value.newBuilder().setStringValue("mocked response").build())
            .build();

    Mockito.when(mockGpuPredictionServiceClient.predict(
            Mockito.any(EndpointName.class),
            Mockito.any(List.class),
            Mockito.any(Value.class)))
        .thenAnswer(invocation -> {
          // Check what the sample actually passed to predict, not locally built values.
          List<Value> instances = invocation.getArgument(1);
          Value parameters = invocation.getArgument(2);
          mockGpuResponse(instances, parameters);
          assertEquals(expectedInstancesGpu, instances);
          assertEquals(expectedParameters, parameters);
          return cannedResponse;
        });

    Mockito.when(mockTpuPredictionServiceClient.predict(
            Mockito.any(EndpointName.class),
            Mockito.any(List.class),
            Mockito.any(Value.class)))
        .thenAnswer(invocation -> {
          List<Value> instances = invocation.getArgument(1);
          Value parameters = invocation.getArgument(2);
          mockTpuResponse(instances, parameters);
          assertEquals(expectedInstancesTpu, instances);
          assertEquals(expectedParameters, parameters);
          return cannedResponse;
        });

    // Run the samples so the stubbed answers (and the assertions inside them) execute.
    new Gemma2PredictGpu(mockGpuPredictionServiceClient)
        .gemma2PredictGpu(PROJECT_ID, REGION, ENDPOINT_ID);
    new Gemma2PredictTpu(mockTpuPredictionServiceClient)
        .gemma2PredictTpu(PROJECT_ID, REGION, ENDPOINT_ID);

    // Guard against a silently skipped call: each sample must have hit predict exactly once.
    Mockito.verify(mockGpuPredictionServiceClient).predict(
        Mockito.any(EndpointName.class),
        Mockito.any(List.class),
        Mockito.any(Value.class));
    Mockito.verify(mockTpuPredictionServiceClient).predict(
        Mockito.any(EndpointName.class),
        Mockito.any(List.class),
        Mockito.any(Value.class));
  }

  /**
   * Asserts that a GPU request carries an {@code inputs} instance field and all four sampling
   * parameters.
   *
   * @param instances instances passed to {@code predict}
   * @param parameter parameters passed to {@code predict}
   * @return always {@code null}; the return type is kept for signature compatibility
   */
  public static Answer<?> mockGpuResponse(List<Value> instances, Value parameter) {

    assertTrue(instances.get(0).getStructValue().getFieldsMap().containsKey("inputs"));
    assertTrue(parameter.getStructValue().containsFields("temperature"));
    assertTrue(parameter.getStructValue().containsFields("maxOutputTokens"));
    assertTrue(parameter.getStructValue().containsFields("topP"));
    assertTrue(parameter.getStructValue().containsFields("topK"));
    return null;
  }

  /**
   * Asserts that a TPU request carries a {@code prompt} instance field and all four sampling
   * parameters.
   *
   * @param instances instances passed to {@code predict}
   * @param parameter parameters passed to {@code predict}
   * @return always {@code null}; the return type is kept for signature compatibility
   */
  public static Answer<?> mockTpuResponse(List<Value> instances, Value parameter) {

    assertTrue(instances.get(0).getStructValue().getFieldsMap().containsKey("prompt"));
    assertTrue(parameter.getStructValue().containsFields("temperature"));
    assertTrue(parameter.getStructValue().containsFields("maxOutputTokens"));
    assertTrue(parameter.getStructValue().containsFields("topP"));
    assertTrue(parameter.getStructValue().containsFields("topK"));
    return null;
  }

  /** Parses a JSON instance and wraps it in a single-element list, as the samples do. */
  private static List<Value> toInstances(String instanceJson)
      throws InvalidProtocolBufferException {
    Value.Builder instanceValue = Value.newBuilder();
    JsonFormat.parser().merge(instanceJson, instanceValue);
    List<Value> instances = new ArrayList<>();
    instances.add(instanceValue.build());
    return instances;
  }

  /** Converts a map to a protobuf {@link Value} by way of its Gson JSON serialization. */
  private static Value mapToValue(Map<String, Object> map) throws InvalidProtocolBufferException {
    Gson gson = new Gson();
    String json = gson.toJson(map);
    Value.Builder builder = Value.newBuilder();
    JsonFormat.parser().merge(json, builder);
    return builder.build();
  }
}

75 changes: 75 additions & 0 deletions aiplatform/src/test/java/aiplatform/Gemma2PredictTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
/*
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package aiplatform;

import static org.junit.jupiter.api.Assertions.assertEquals;

import com.google.cloud.aiplatform.v1.EndpointName;
import com.google.cloud.aiplatform.v1.PredictResponse;
import com.google.cloud.aiplatform.v1.PredictionServiceClient;
import com.google.protobuf.Value;
import java.io.IOException;
import java.util.List;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.mockito.Mockito;

/**
 * Unit tests for the Gemma2 GPU and TPU prediction samples.
 *
 * <p>Both samples run against one shared mocked {@link PredictionServiceClient} that returns a
 * canned response for any {@code predict} call.
 */
public class Gemma2PredictTest {
  static String mockedResponse = "The sky appears blue due to a phenomenon "
      + "called **Rayleigh scattering**.\n"
      + "**Here's how it works:**\n"
      + "* **Sunlight is white:** Sunlight actually contains all the colors of the rainbow.\n"
      + "* **Scattering:** When sunlight enters the Earth's atmosphere, it collides with tiny gas"
      + " molecules (mostly nitrogen and oxygen). These collisions cause the light to scatter "
      + "in different directions.\n"
      + "* **Blue light scatters most:** Blue light has a shorter wavelength";
  String projectId = "your-project-id";
  String region = "us-central1";
  String endpointId = "your-endpoint-id";
  static PredictionServiceClient mockPredictionServiceClient;

  /** Builds the canned response and stubs the shared mock client to return it. */
  @BeforeAll
  public static void setUp() {
    Value cannedPrediction = Value.newBuilder().setStringValue(mockedResponse).build();
    PredictResponse cannedResponse =
        PredictResponse.newBuilder().addPredictions(cannedPrediction).build();

    mockPredictionServiceClient = Mockito.mock(PredictionServiceClient.class);
    Mockito.doReturn(cannedResponse)
        .when(mockPredictionServiceClient)
        .predict(
            Mockito.any(EndpointName.class),
            Mockito.anyList(),
            Mockito.any(Value.class));
  }

  /** The TPU sample returns the text of the first prediction unchanged. */
  @Test
  public void testGemma2PredictTpu() throws IOException {
    Gemma2PredictTpu tpuSample = new Gemma2PredictTpu(mockPredictionServiceClient);

    String actual = tpuSample.gemma2PredictTpu(projectId, region, endpointId);

    assertEquals(mockedResponse, actual);
  }

  /** The GPU sample returns the text of the first prediction unchanged. */
  @Test
  public void testGemma2PredictGpu() throws IOException {
    Gemma2PredictGpu gpuSample = new Gemma2PredictGpu(mockPredictionServiceClient);

    String actual = gpuSample.gemma2PredictGpu(projectId, region, endpointId);

    assertEquals(mockedResponse, actual);
  }
}