Skip to content
Merged
6 changes: 6 additions & 0 deletions aiplatform/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,12 @@
<version>1.7.1</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-core</artifactId>
<version>5.13.0</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter</artifactId>
Expand Down
92 changes: 92 additions & 0 deletions aiplatform/src/main/java/aiplatform/Gemma2PredictGpu.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
/*
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package aiplatform;

// [START generativeaionvertexai_gemma2_predict_gpu]

import com.google.cloud.aiplatform.v1.EndpointName;
import com.google.cloud.aiplatform.v1.PredictResponse;
import com.google.cloud.aiplatform.v1.PredictionServiceClient;
import com.google.cloud.aiplatform.v1.PredictionServiceSettings;
import com.google.protobuf.Value;
import com.google.protobuf.util.JsonFormat;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

/**
 * Sample that sends a text prompt to a Gemma2 model deployed to a Vertex AI endpoint
 * backed by GPU accelerators and prints the model's answer.
 */
public class Gemma2PredictGpu {

  private final PredictionServiceClient predictionServiceClient;

  /**
   * Creates the sample around an existing client.
   *
   * @param predictionServiceClient client used to send the prediction request. It is injected
   *     (rather than created here) so tests can supply a mock; this class never closes it, so
   *     the caller remains responsible for its lifecycle.
   */
  public Gemma2PredictGpu(PredictionServiceClient predictionServiceClient) {
    this.predictionServiceClient = predictionServiceClient;
  }

  public static void main(String[] args) throws IOException {
    // TODO(developer): Replace these variables before running the sample.
    String projectId = "YOUR_PROJECT_ID";
    String endpointRegion = "us-east4";
    String endpointId = "YOUR_ENDPOINT_ID";
    String parameters =
        "{\n"
            + "  \"temperature\": 0.9,\n"
            + "  \"maxOutputTokens\": 1024,\n"
            + "  \"topP\": 1.0,\n"
            + "  \"topK\": 1\n"
            + "}";

    PredictionServiceSettings predictionServiceSettings =
        PredictionServiceSettings.newBuilder()
            .setEndpoint(String.format("%s-aiplatform.googleapis.com:443", endpointRegion))
            .build();

    // The client holds network resources, so close it once the request has completed.
    try (PredictionServiceClient predictionServiceClient =
        PredictionServiceClient.create(predictionServiceSettings)) {
      Gemma2PredictGpu creator = new Gemma2PredictGpu(predictionServiceClient);
      creator.gemma2PredictGpu(projectId, endpointRegion, endpointId, parameters);
    }
  }

  /**
   * Demonstrates how to run inference on a Gemma2 model deployed to a Vertex AI endpoint with
   * GPU accelerators.
   *
   * @param projectId project that owns the endpoint
   * @param region location of the endpoint, e.g. {@code us-east4}
   * @param endpointId ID of the endpoint the model is deployed to
   * @param parameters JSON object with the generation parameters (temperature, topK, ...)
   * @return the string value of the first prediction in the response; also printed to stdout
   * @throws IOException if {@code parameters} or the prompt cannot be parsed as JSON
   */
  public String gemma2PredictGpu(String projectId, String region,
      String endpointId, String parameters) throws IOException {
    // Prompt used in the prediction
    String instance = "{ \"inputs\": \"Why is the sky blue?\"}";

    Value.Builder parameterValueBuilder = Value.newBuilder();
    JsonFormat.parser().merge(parameters, parameterValueBuilder);
    Value parameterValue = parameterValueBuilder.build();

    Value.Builder instanceValue = Value.newBuilder();
    JsonFormat.parser().merge(instance, instanceValue);

    // Encapsulate the prompt in the format used for GPU endpoints.
    // Instances sent by this sample: [{'inputs': 'Why is the sky blue?'}]
    // The generation parameters are passed separately as the request's `parameters` value.
    List<Value> instances = new ArrayList<>();
    instances.add(instanceValue.build());

    EndpointName endpointName = EndpointName.of(projectId, region, endpointId);

    PredictResponse predictResponse = this.predictionServiceClient
        .predict(endpointName, instances, parameterValue);
    String textResponse = predictResponse.getPredictions(0).getStringValue();
    System.out.println(textResponse);
    return textResponse;
  }
}
// [END generativeaionvertexai_gemma2_predict_gpu]
92 changes: 92 additions & 0 deletions aiplatform/src/main/java/aiplatform/Gemma2PredictTpu.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
/*
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package aiplatform;

// [START generativeaionvertexai_gemma2_predict_tpu]

import com.google.cloud.aiplatform.v1.EndpointName;
import com.google.cloud.aiplatform.v1.PredictResponse;
import com.google.cloud.aiplatform.v1.PredictionServiceClient;
import com.google.cloud.aiplatform.v1.PredictionServiceSettings;
import com.google.protobuf.Value;
import com.google.protobuf.util.JsonFormat;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

/**
 * Sample that sends a text prompt to a Gemma2 model deployed to a Vertex AI endpoint
 * backed by TPU accelerators and prints the model's answer.
 */
public class Gemma2PredictTpu {

  private final PredictionServiceClient predictionServiceClient;

  /**
   * Creates the sample around an existing client.
   *
   * @param predictionServiceClient client used to send the prediction request. It is injected
   *     (rather than created here) so tests can supply a mock; this class never closes it, so
   *     the caller remains responsible for its lifecycle.
   */
  public Gemma2PredictTpu(PredictionServiceClient predictionServiceClient) {
    this.predictionServiceClient = predictionServiceClient;
  }

  public static void main(String[] args) throws IOException {
    // TODO(developer): Replace these variables before running the sample.
    String projectId = "YOUR_PROJECT_ID";
    String endpointRegion = "us-west1";
    String endpointId = "YOUR_ENDPOINT_ID";
    String parameters =
        "{\n"
            + "  \"temperature\": 0.9,\n"
            + "  \"maxOutputTokens\": 1024,\n"
            + "  \"topP\": 1.0,\n"
            + "  \"topK\": 1\n"
            + "}";

    PredictionServiceSettings predictionServiceSettings =
        PredictionServiceSettings.newBuilder()
            .setEndpoint(String.format("%s-aiplatform.googleapis.com:443", endpointRegion))
            .build();

    // The client holds network resources, so close it once the request has completed.
    try (PredictionServiceClient predictionServiceClient =
        PredictionServiceClient.create(predictionServiceSettings)) {
      Gemma2PredictTpu creator = new Gemma2PredictTpu(predictionServiceClient);
      creator.gemma2PredictTpu(projectId, endpointRegion, endpointId, parameters);
    }
  }

  /**
   * Demonstrates how to run inference on a Gemma2 model deployed to a Vertex AI endpoint with
   * TPU accelerators.
   *
   * @param projectId project that owns the endpoint
   * @param region location of the endpoint, e.g. {@code us-west1}
   * @param endpointId ID of the endpoint the model is deployed to
   * @param parameters JSON object with the generation parameters (temperature, topK, ...)
   * @return the string value of the first prediction in the response; also printed to stdout
   * @throws IOException if {@code parameters} or the prompt cannot be parsed as JSON
   */
  public String gemma2PredictTpu(String projectId, String region,
      String endpointId, String parameters) throws IOException {
    // Prompt used in the prediction
    String instance = "{ \"prompt\": \"Why is the sky blue?\"}";

    Value.Builder parameterValueBuilder = Value.newBuilder();
    JsonFormat.parser().merge(parameters, parameterValueBuilder);
    Value parameterValue = parameterValueBuilder.build();

    Value.Builder instanceValue = Value.newBuilder();
    JsonFormat.parser().merge(instance, instanceValue);

    // Encapsulate the prompt in the format used for TPU endpoints.
    // Instances sent by this sample: [{'prompt': 'Why is the sky blue?'}]
    // The generation parameters are passed separately as the request's `parameters` value.
    List<Value> instances = new ArrayList<>();
    instances.add(instanceValue.build());

    EndpointName endpointName = EndpointName.of(projectId, region, endpointId);

    PredictResponse predictResponse = this.predictionServiceClient
        .predict(endpointName, instances, parameterValue);
    String textResponse = predictResponse.getPredictions(0).getStringValue();
    System.out.println(textResponse);
    return textResponse;
  }
}
// [END generativeaionvertexai_gemma2_predict_tpu]

76 changes: 76 additions & 0 deletions aiplatform/src/test/java/aiplatform/Gemma2PredictTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
/*
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package aiplatform;

import static org.junit.jupiter.api.Assertions.assertEquals;

import com.google.cloud.aiplatform.v1.EndpointName;
import com.google.cloud.aiplatform.v1.PredictResponse;
import com.google.cloud.aiplatform.v1.PredictionServiceClient;
import com.google.protobuf.Value;
import java.io.IOException;
import java.util.List;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.mockito.Mockito;

/**
 * Unit tests for the Gemma2 prediction samples. The Vertex AI client is mocked, so no network
 * call is made and no real project or endpoint is required.
 */
public class Gemma2PredictTest {
  // Canned model answer returned by the mocked client.
  private static final String MOCKED_RESPONSE = "The sky appears blue due to a phenomenon "
      + "called **Rayleigh scattering**.\n"
      + "**Here's how it works:**\n"
      + "* **Sunlight is white:** Sunlight actually contains all the colors of the rainbow.\n"
      + "* **Scattering:** When sunlight enters the Earth's atmosphere, it collides with tiny gas"
      + " molecules (mostly nitrogen and oxygen). These collisions cause the light to scatter "
      + "in different directions.\n"
      + "* **Blue light scatters most:** Blue light has a shorter wavelength";
  // Placeholder identifiers; the mock accepts any endpoint, so these are never resolved.
  private static final String PROJECT_ID = "your-project-id";
  private static final String REGION = "us-central1";
  private static final String ENDPOINT_ID = "your-endpoint-id";
  private static final String PARAMETERS = "{}";

  private static PredictionServiceClient mockPredictionServiceClient;

  /** Builds one shared mock client that answers every predict call with the canned response. */
  @BeforeAll
  public static void setUp() {
    // Mock PredictionServiceClient and its response
    mockPredictionServiceClient = Mockito.mock(PredictionServiceClient.class);
    PredictResponse predictResponse =
        PredictResponse.newBuilder()
            .addPredictions(Value.newBuilder().setStringValue(MOCKED_RESPONSE).build())
            .build();
    // anyList() keeps the matcher generic-safe (no raw List.class / unchecked conversion).
    Mockito.when(mockPredictionServiceClient.predict(
            Mockito.any(EndpointName.class),
            Mockito.anyList(),
            Mockito.any(Value.class)))
        .thenReturn(predictResponse);
  }

  @Test
  public void testGemma2PredictTpu() throws IOException {
    Gemma2PredictTpu creator = new Gemma2PredictTpu(mockPredictionServiceClient);
    String response = creator.gemma2PredictTpu(PROJECT_ID, REGION, ENDPOINT_ID, PARAMETERS);

    assertEquals(MOCKED_RESPONSE, response);
  }

  @Test
  public void testGemma2PredictGpu() throws IOException {
    Gemma2PredictGpu creator = new Gemma2PredictGpu(mockPredictionServiceClient);
    String response = creator.gemma2PredictGpu(PROJECT_ID, REGION, ENDPOINT_ID, PARAMETERS);

    assertEquals(MOCKED_RESPONSE, response);
  }
}
109 changes: 109 additions & 0 deletions aiplatform/src/test/java/aiplatform/Gemma2PredictionsTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
package aiplatform;

import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.Mockito.*;

import com.google.cloud.aiplatform.v1.EndpointName;
import com.google.cloud.aiplatform.v1.PredictResponse;
import com.google.cloud.aiplatform.v1.PredictionServiceClient;
import com.google.cloud.aiplatform.v1.stub.PredictionServiceStub;
import com.google.protobuf.Value;
import com.google.protobuf.util.JsonFormat;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import org.junit.jupiter.api.Test;
import org.mockito.Mockito;


/**
 * Unit tests for the Gemma2 GPU and TPU prediction samples.
 *
 * <p>The {@link PredictionServiceClient} itself is mocked (not built from a mocked stub), so
 * {@code Mockito.when} and {@code Mockito.verify} operate on a real mock and no network call
 * is made.
 */
public class Gemma2PredictionsTest {

  // Canned model answer returned by the mocked client.
  private static final String RESPONSE =
      "The sky appears blue due to a phenomenon called **Rayleigh scattering**.\n"
          + "**Here's how it works:**\n"
          + "1. **Sunlight:** Sunlight is composed of all the colors of the rainbow.\n"
          + "2. **Earth's Atmosphere:** When sunlight enters the Earth's atmosphere, it collides with tiny particles like nitrogen and oxygen molecules.\n"
          + "3. **Scattering:** These particles scatter the sunlight in all directions. However, blue light (which has a shorter wavelength) is scattered more effectively than other colors.\n"
          + "4. **Our Perception:** As a result, we see a blue sky because the scattered blue light reaches our eyes from all directions.\n"
          + "**Why not other colors?**\n"
          + "* **Violet light** has an even shorter wavelength than blue and is scattered even more. However, our eyes are less sensitive to violet light, so we perceive the sky as blue.\n"
          + "* **Longer wavelengths** like red, orange, and yellow are scattered less and travel more directly through the atmosphere. This is why we see these colors during sunrise and sunset, when sunlight has to travel through more of the atmosphere.\n";

  // NOTE(review): this looks like a real project ID; the client is mocked, so a placeholder
  // would work equally well here - confirm and replace.
  private static final String PROJECT_ID = "rsamborski-ai-hypercomputer";
  private static final String GPU_ENDPOINT_REGION = "us-east1";
  private static final String GPU_ENDPOINT_ID = "323876543124209664";
  private static final String TPU_ENDPOINT_REGION = "us-west1";
  private static final String TPU_ENDPOINT_ID = "9194824316951199744";
  private static final String PARAMETERS =
      "{\n"
          + "  \"temperature\": 0.3,\n"
          + "  \"maxOutputTokens\": 200,\n"
          + "  \"topP\": 0.8,\n"
          + "  \"topK\": 40\n"
          + "}";

  @Test
  public void testShouldRunInterferenceWithGPU() throws IOException {
    PredictionServiceClient client = mockClientReturningResponse();
    Value parameterValue = parseJson(PARAMETERS);
    List<Value> instances = new ArrayList<>();
    instances.add(parseJson("{ \"inputs\": \"Why is the sky blue?\"}"));

    Gemma2PredictGpu gemma2PredictGpu = new Gemma2PredictGpu(client);
    String output = gemma2PredictGpu.gemma2PredictGpu(
        PROJECT_ID, GPU_ENDPOINT_REGION, GPU_ENDPOINT_ID, PARAMETERS);

    assertTrue(output.contains("Rayleigh scattering"));
    // The sample must hit the GPU endpoint with the "inputs" instance and the parsed parameters.
    Mockito.verify(client, Mockito.times(1)).predict(
        EndpointName.of(PROJECT_ID, GPU_ENDPOINT_REGION, GPU_ENDPOINT_ID),
        instances,
        parameterValue);
    assertTrue(instances.get(0).getStructValue().containsFields("inputs"));
    assertHasGenerationParameters(parameterValue);
  }

  @Test
  public void testShouldRunInterferenceWithTPU() throws IOException {
    PredictionServiceClient client = mockClientReturningResponse();
    Value parameterValue = parseJson(PARAMETERS);
    List<Value> instances = new ArrayList<>();
    instances.add(parseJson("{ \"prompt\": \"Why is the sky blue?\"}"));

    Gemma2PredictTpu gemma2PredictTpu = new Gemma2PredictTpu(client);
    String output = gemma2PredictTpu.gemma2PredictTpu(
        PROJECT_ID, TPU_ENDPOINT_REGION, TPU_ENDPOINT_ID, PARAMETERS);

    assertTrue(output.contains("Rayleigh scattering"));
    // The sample must hit the TPU endpoint with the "prompt" instance and the parsed parameters.
    Mockito.verify(client, Mockito.times(1)).predict(
        EndpointName.of(PROJECT_ID, TPU_ENDPOINT_REGION, TPU_ENDPOINT_ID),
        instances,
        parameterValue);
    assertTrue(instances.get(0).getStructValue().containsFields("prompt"));
    assertHasGenerationParameters(parameterValue);
  }

  /** Creates a mocked client whose EndpointName-based predict call returns {@link #RESPONSE}. */
  private static PredictionServiceClient mockClientReturningResponse() {
    PredictionServiceClient client = Mockito.mock(PredictionServiceClient.class);
    PredictResponse mockResponse = PredictResponse.newBuilder()
        .addPredictions(Value.newBuilder().setStringValue(RESPONSE).build())
        .build();
    // Stub the same overload the samples call: predict(EndpointName, List<Value>, Value).
    Mockito.when(client.predict(
            Mockito.any(EndpointName.class),
            Mockito.anyList(),
            Mockito.any(Value.class)))
        .thenReturn(mockResponse);
    return client;
  }

  /** Parses a JSON document into a protobuf {@link Value}, as the samples do. */
  private static Value parseJson(String json) throws IOException {
    Value.Builder builder = Value.newBuilder();
    JsonFormat.parser().merge(json, builder);
    return builder.build();
  }

  /** Asserts that all four generation parameters from {@link #PARAMETERS} were parsed. */
  private static void assertHasGenerationParameters(Value parameterValue) {
    assertTrue(parameterValue.getStructValue().containsFields("temperature"));
    assertTrue(parameterValue.getStructValue().containsFields("maxOutputTokens"));
    assertTrue(parameterValue.getStructValue().containsFields("topP"));
    assertTrue(parameterValue.getStructValue().containsFields("topK"));
  }
}