Skip to content
Merged
6 changes: 6 additions & 0 deletions aiplatform/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,12 @@
<version>1.7.1</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-core</artifactId>
<version>5.13.0</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter</artifactId>
Expand Down
90 changes: 90 additions & 0 deletions aiplatform/src/main/java/aiplatform/Gemma2PredictGpu.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
/*
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package aiplatform;

// [START generativeaionvertexai_gemma2_predict_gpu]

import com.google.cloud.aiplatform.v1.EndpointName;
import com.google.cloud.aiplatform.v1.PredictResponse;
import com.google.cloud.aiplatform.v1.PredictionServiceClient;
import com.google.cloud.aiplatform.v1.PredictionServiceSettings;
import com.google.protobuf.Value;
import com.google.protobuf.util.JsonFormat;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

public class Gemma2PredictGpu {

private final PredictionServiceClient predictionServiceClient;

// Constructor to inject the PredictionServiceClient
public Gemma2PredictGpu(PredictionServiceClient predictionServiceClient) {
this.predictionServiceClient = predictionServiceClient;
}

public static void main(String[] args) throws IOException {
// TODO(developer): Replace these variables before running the sample.
String projectId = "YOUR_PROJECT_ID";
String region = "us-east4";
String endpointId = "YOUR_ENDPOINT_ID";
String parameters =
"{\n"
+ " \"temperature\": 0.3,\n"
+ " \"maxDecodeSteps\": 200,\n"
+ " \"topP\": 0.8,\n"
+ " \"topK\": 40\n"
+ "}";

PredictionServiceSettings predictionServiceSettings =
PredictionServiceSettings.newBuilder()
.setEndpoint(String.format("%s-aiplatform.googleapis.com:443", region))
.build();
PredictionServiceClient predictionServiceClient =
PredictionServiceClient.create(predictionServiceSettings);
Gemma2PredictGpu creator = new Gemma2PredictGpu(predictionServiceClient);

creator.gemma2PredictGpu(projectId, region, endpointId, parameters);
}

// Demonstrates how to run interference on a Gemma2 model
// deployed to a Vertex AI endpoint with GPU accelerators.
public String gemma2PredictGpu(String projectId, String region,
String endpointId, String parameters) throws IOException {
// Prompt used in the prediction
String instance = "{ \"inputs\": \"Why is the sky blue?\"}";

Value.Builder parameterValueBuilder = Value.newBuilder();
JsonFormat.parser().merge(parameters, parameterValueBuilder);
Value parameterValue = parameterValueBuilder.build();

Value.Builder instanceValue = Value.newBuilder();
JsonFormat.parser().merge(instance, instanceValue);

List<Value> instances = new ArrayList<>();
instances.add(instanceValue.build());

EndpointName endpointName = EndpointName.of(projectId, region, endpointId);

PredictResponse predictResponse = this.predictionServiceClient
.predict(endpointName, instances, parameterValue);
String textResponse = predictResponse.getPredictions(0).getStringValue();
System.out.println(textResponse);
return textResponse;
}
}
// [END generativeaionvertexai_gemma2_predict_gpu]
90 changes: 90 additions & 0 deletions aiplatform/src/main/java/aiplatform/Gemma2PredictTpu.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
/*
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package aiplatform;

// [START generativeaionvertexai_gemma2_predict_tpu]

import com.google.cloud.aiplatform.v1.EndpointName;
import com.google.cloud.aiplatform.v1.PredictResponse;
import com.google.cloud.aiplatform.v1.PredictionServiceClient;
import com.google.cloud.aiplatform.v1.PredictionServiceSettings;
import com.google.protobuf.Value;
import com.google.protobuf.util.JsonFormat;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

public class Gemma2PredictTpu {
private final PredictionServiceClient predictionServiceClient;

// Constructor to inject the PredictionServiceClient
public Gemma2PredictTpu(PredictionServiceClient predictionServiceClient) {
this.predictionServiceClient = predictionServiceClient;
}

public static void main(String[] args) throws IOException {
// TODO(developer): Replace these variables before running the sample.
String projectId = "YOUR_PROJECT_ID";
String region = "us-west1";
String endpointId = "YOUR_ENDPOINT_ID";
String parameters =
"{\n"
+ " \"temperature\": 0.3,\n"
+ " \"maxDecodeSteps\": 200,\n"
+ " \"topP\": 0.8,\n"
+ " \"topK\": 40\n"
+ "}";

PredictionServiceSettings predictionServiceSettings =
PredictionServiceSettings.newBuilder()
.setEndpoint(String.format("%s-aiplatform.googleapis.com:443", region))
.build();
PredictionServiceClient predictionServiceClient =
PredictionServiceClient.create(predictionServiceSettings);
Gemma2PredictTpu creator = new Gemma2PredictTpu(predictionServiceClient);

creator.gemma2PredictTpu(projectId, region, endpointId, parameters);
}

// Demonstrates how to run interference on a Gemma2 model
// deployed to a Vertex AI endpoint with TPU accelerators.
public String gemma2PredictTpu(String projectId, String region,
String endpointId, String parameters) throws IOException {
// Prompt used in the prediction
String instance = "{ \"prompt\": \"Why is the sky blue?\"}";

Value.Builder parameterValueBuilder = Value.newBuilder();
JsonFormat.parser().merge(parameters, parameterValueBuilder);
Value parameterValue = parameterValueBuilder.build();

Value.Builder instanceValue = Value.newBuilder();
JsonFormat.parser().merge(instance, instanceValue);

List<Value> instances = new ArrayList<>();
instances.add(instanceValue.build());

EndpointName endpointName = EndpointName.of(projectId, region, endpointId);

PredictResponse predictResponse = this.predictionServiceClient
.predict(endpointName, instances, parameterValue);
String textResponse = predictResponse.getPredictions(0).getStringValue();
System.out.println(textResponse);
return textResponse;
}
}
// [END generativeaionvertexai_gemma2_predict_tpu]

76 changes: 76 additions & 0 deletions aiplatform/src/test/java/aiplatform/Gemma2PredictTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
/*
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package aiplatform;

import static org.junit.jupiter.api.Assertions.assertEquals;

import com.google.cloud.aiplatform.v1.EndpointName;
import com.google.cloud.aiplatform.v1.PredictResponse;
import com.google.cloud.aiplatform.v1.PredictionServiceClient;
import com.google.protobuf.Value;
import java.io.IOException;
import java.util.List;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.mockito.Mockito;

public class Gemma2PredictTest {
static String mockedResponse = "The sky appears blue due to a phenomenon "
+ "called **Rayleigh scattering**.\n"
+ "**Here's how it works:**\n"
+ "* **Sunlight is white:** Sunlight actually contains all the colors of the rainbow.\n"
+ "* **Scattering:** When sunlight enters the Earth's atmosphere, it collides with tiny gas"
+ " molecules (mostly nitrogen and oxygen). These collisions cause the light to scatter "
+ "in different directions.\n"
+ "* **Blue light scatters most:** Blue light has a shorter wavelength";
String projectId = "your-project-id";
String region = "us-central1";
String endpointId = "your-endpoint-id";
String parameters = "{}";
static PredictionServiceClient mockPredictionServiceClient;

@BeforeAll
public static void setUp() {
// Mock PredictionServiceClient and its response
mockPredictionServiceClient = Mockito.mock(PredictionServiceClient.class);
PredictResponse predictResponse =
PredictResponse.newBuilder()
.addPredictions(Value.newBuilder().setStringValue(mockedResponse).build())
.build();
Mockito.when(mockPredictionServiceClient.predict(
Mockito.any(EndpointName.class),
Mockito.any(List.class),
Mockito.any(Value.class)))
.thenReturn(predictResponse);
}

@Test
public void testGemma2PredictTpu() throws IOException {
Gemma2PredictTpu creator = new Gemma2PredictTpu(mockPredictionServiceClient);
String response = creator.gemma2PredictTpu(projectId, region, endpointId, parameters);

assertEquals(mockedResponse, response);
}

@Test
public void testGemma2PredictGpu() throws IOException {
Gemma2PredictGpu creator = new Gemma2PredictGpu(mockPredictionServiceClient);
String response = creator.gemma2PredictGpu(projectId, region, endpointId, parameters);

assertEquals(mockedResponse, response);
}
}