-
Notifications
You must be signed in to change notification settings - Fork 2.9k
feat(aiplatform): add gemma2 samples for Model Garden deployments to VertexAI endpoints #9527
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 8 commits
Commits
Show all changes
11 commits
Select commit
Hold shift + click to select a range
b0e22b0
Added generativeaionvertexai_gemma2_predict_tpu and generativeaionver…
TetyanaYahodska 10b1bf3
Fixed instance format
TetyanaYahodska 3ac56f9
Fixed test and instance format for Gemma2PredictTpu
TetyanaYahodska 61e99b8
Merge changes from main
TetyanaYahodska 8632dbe
Deleted class for vertexai package
TetyanaYahodska 0c7608f
Added generativeaionvertexai_gemma2_predict_gpu and generativeaionver…
TetyanaYahodska f3538b8
Fixed comments
TetyanaYahodska fe54ef1
added comments, created new test
TetyanaYahodska 5ddc361
Fixed parameters, created test to check parameters
TetyanaYahodska a2e3f43
Fixed comments
TetyanaYahodska 60f1ed8
Fixed comments
TetyanaYahodska File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
/* | ||
* Copyright 2024 Google LLC | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package aiplatform; | ||
|
||
// [START generativeaionvertexai_gemma2_predict_gpu] | ||
|
||
import com.google.cloud.aiplatform.v1.EndpointName; | ||
import com.google.cloud.aiplatform.v1.PredictResponse; | ||
import com.google.cloud.aiplatform.v1.PredictionServiceClient; | ||
import com.google.cloud.aiplatform.v1.PredictionServiceSettings; | ||
import com.google.protobuf.Value; | ||
import com.google.protobuf.util.JsonFormat; | ||
import java.io.IOException; | ||
import java.util.ArrayList; | ||
import java.util.List; | ||
|
||
public class Gemma2PredictGpu { | ||
|
||
private final PredictionServiceClient predictionServiceClient; | ||
|
||
// Constructor to inject the PredictionServiceClient | ||
public Gemma2PredictGpu(PredictionServiceClient predictionServiceClient) { | ||
this.predictionServiceClient = predictionServiceClient; | ||
} | ||
|
||
public static void main(String[] args) throws IOException { | ||
// TODO(developer): Replace these variables before running the sample. | ||
String projectId = "YOUR_PROJECT_ID"; | ||
String endpointRegion = "us-east4"; | ||
String endpointId = "YOUR_ENDPOINT_ID"; | ||
String parameters = | ||
"{\n" | ||
+ " \"temperature\": 0.9,\n" | ||
+ " \"maxOutputTokens\": 1024,\n" | ||
+ " \"topP\": 1.0,\n" | ||
+ " \"topK\": 1\n" | ||
+ "}"; | ||
|
||
PredictionServiceSettings predictionServiceSettings = | ||
PredictionServiceSettings.newBuilder() | ||
.setEndpoint(String.format("%s-aiplatform.googleapis.com:443", endpointRegion)) | ||
.build(); | ||
PredictionServiceClient predictionServiceClient = | ||
PredictionServiceClient.create(predictionServiceSettings); | ||
TetyanaYahodska marked this conversation as resolved.
Show resolved
Hide resolved
|
||
Gemma2PredictGpu creator = new Gemma2PredictGpu(predictionServiceClient); | ||
|
||
creator.gemma2PredictGpu(projectId, endpointRegion, endpointId, parameters); | ||
} | ||
|
||
// Demonstrates how to run interference on a Gemma2 model | ||
TetyanaYahodska marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
// deployed to a Vertex AI endpoint with GPU accelerators. | ||
public String gemma2PredictGpu(String projectId, String region, | ||
String endpointId, String parameters) throws IOException { | ||
// Prompt used in the prediction | ||
String instance = "{ \"inputs\": \"Why is the sky blue?\"}"; | ||
|
||
Value.Builder parameterValueBuilder = Value.newBuilder(); | ||
JsonFormat.parser().merge(parameters, parameterValueBuilder); | ||
Value parameterValue = parameterValueBuilder.build(); | ||
|
||
Value.Builder instanceValue = Value.newBuilder(); | ||
JsonFormat.parser().merge(instance, instanceValue); | ||
|
||
// Encapsulate the prompt in a correct format for GPUs | ||
// Example format: [{'prompt': 'Why is the sky blue?', 'parameters': {'temperature': 0.8}}] | ||
TetyanaYahodska marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
List<Value> instances = new ArrayList<>(); | ||
instances.add(instanceValue.build()); | ||
rsamborski marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
EndpointName endpointName = EndpointName.of(projectId, region, endpointId); | ||
|
||
PredictResponse predictResponse = this.predictionServiceClient | ||
.predict(endpointName, instances, parameterValue); | ||
String textResponse = predictResponse.getPredictions(0).getStringValue(); | ||
System.out.println(textResponse); | ||
return textResponse; | ||
} | ||
} | ||
// [END generativeaionvertexai_gemma2_predict_gpu] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
/* | ||
* Copyright 2024 Google LLC | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package aiplatform; | ||
|
||
// [START generativeaionvertexai_gemma2_predict_tpu] | ||
|
||
import com.google.cloud.aiplatform.v1.EndpointName; | ||
import com.google.cloud.aiplatform.v1.PredictResponse; | ||
import com.google.cloud.aiplatform.v1.PredictionServiceClient; | ||
import com.google.cloud.aiplatform.v1.PredictionServiceSettings; | ||
import com.google.protobuf.Value; | ||
import com.google.protobuf.util.JsonFormat; | ||
import java.io.IOException; | ||
import java.util.ArrayList; | ||
import java.util.List; | ||
|
||
public class Gemma2PredictTpu { | ||
private final PredictionServiceClient predictionServiceClient; | ||
|
||
// Constructor to inject the PredictionServiceClient | ||
public Gemma2PredictTpu(PredictionServiceClient predictionServiceClient) { | ||
this.predictionServiceClient = predictionServiceClient; | ||
} | ||
TetyanaYahodska marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
public static void main(String[] args) throws IOException { | ||
// TODO(developer): Replace these variables before running the sample. | ||
String projectId = "YOUR_PROJECT_ID"; | ||
String endpointRegion = "us-west1"; | ||
String endpointId = "YOUR_ENDPOINT_ID"; | ||
String parameters = | ||
"{\n" | ||
+ " \"temperature\": 0.9,\n" | ||
+ " \"maxOutputTokens\": 1024,\n" | ||
+ " \"topP\": 1.0,\n" | ||
+ " \"topK\": 1\n" | ||
+ "}"; | ||
|
||
PredictionServiceSettings predictionServiceSettings = | ||
PredictionServiceSettings.newBuilder() | ||
.setEndpoint(String.format("%s-aiplatform.googleapis.com:443", endpointRegion)) | ||
.build(); | ||
PredictionServiceClient predictionServiceClient = | ||
PredictionServiceClient.create(predictionServiceSettings); | ||
Gemma2PredictTpu creator = new Gemma2PredictTpu(predictionServiceClient); | ||
|
||
creator.gemma2PredictTpu(projectId, endpointRegion, endpointId, parameters); | ||
} | ||
|
||
// Demonstrates how to run interference on a Gemma2 model | ||
// deployed to a Vertex AI endpoint with TPU accelerators. | ||
public String gemma2PredictTpu(String projectId, String region, | ||
String endpointId, String parameters) throws IOException { | ||
// Prompt used in the prediction | ||
String instance = "{ \"prompt\": \"Why is the sky blue?\"}"; | ||
|
||
Value.Builder parameterValueBuilder = Value.newBuilder(); | ||
JsonFormat.parser().merge(parameters, parameterValueBuilder); | ||
Value parameterValue = parameterValueBuilder.build(); | ||
|
||
Value.Builder instanceValue = Value.newBuilder(); | ||
JsonFormat.parser().merge(instance, instanceValue); | ||
|
||
// Encapsulate the prompt in a correct format for TPUs | ||
// Example format: [{'prompt': 'Why is the sky blue?', 'parameters': {'temperature': 0.8}}] | ||
TetyanaYahodska marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
List<Value> instances = new ArrayList<>(); | ||
instances.add(instanceValue.build()); | ||
TetyanaYahodska marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
EndpointName endpointName = EndpointName.of(projectId, region, endpointId); | ||
|
||
PredictResponse predictResponse = this.predictionServiceClient | ||
.predict(endpointName, instances, parameterValue); | ||
String textResponse = predictResponse.getPredictions(0).getStringValue(); | ||
System.out.println(textResponse); | ||
return textResponse; | ||
} | ||
} | ||
// [END generativeaionvertexai_gemma2_predict_tpu] | ||
|
76 changes: 76 additions & 0 deletions
76
aiplatform/src/test/java/aiplatform/Gemma2PredictTest.java
TetyanaYahodska marked this conversation as resolved.
Show resolved
Hide resolved
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
/* | ||
* Copyright 2024 Google LLC | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package aiplatform; | ||
|
||
import static org.junit.jupiter.api.Assertions.assertEquals; | ||
|
||
import com.google.cloud.aiplatform.v1.EndpointName; | ||
import com.google.cloud.aiplatform.v1.PredictResponse; | ||
import com.google.cloud.aiplatform.v1.PredictionServiceClient; | ||
import com.google.protobuf.Value; | ||
import java.io.IOException; | ||
import java.util.List; | ||
import org.junit.jupiter.api.BeforeAll; | ||
import org.junit.jupiter.api.Test; | ||
import org.mockito.Mockito; | ||
|
||
public class Gemma2PredictTest { | ||
static String mockedResponse = "The sky appears blue due to a phenomenon " | ||
+ "called **Rayleigh scattering**.\n" | ||
+ "**Here's how it works:**\n" | ||
+ "* **Sunlight is white:** Sunlight actually contains all the colors of the rainbow.\n" | ||
+ "* **Scattering:** When sunlight enters the Earth's atmosphere, it collides with tiny gas" | ||
+ " molecules (mostly nitrogen and oxygen). These collisions cause the light to scatter " | ||
+ "in different directions.\n" | ||
+ "* **Blue light scatters most:** Blue light has a shorter wavelength"; | ||
String projectId = "your-project-id"; | ||
String region = "us-central1"; | ||
String endpointId = "your-endpoint-id"; | ||
String parameters = "{}"; | ||
static PredictionServiceClient mockPredictionServiceClient; | ||
|
||
@BeforeAll | ||
public static void setUp() { | ||
// Mock PredictionServiceClient and its response | ||
mockPredictionServiceClient = Mockito.mock(PredictionServiceClient.class); | ||
PredictResponse predictResponse = | ||
PredictResponse.newBuilder() | ||
.addPredictions(Value.newBuilder().setStringValue(mockedResponse).build()) | ||
.build(); | ||
Mockito.when(mockPredictionServiceClient.predict( | ||
Mockito.any(EndpointName.class), | ||
Mockito.any(List.class), | ||
Mockito.any(Value.class))) | ||
.thenReturn(predictResponse); | ||
} | ||
|
||
@Test | ||
public void testGemma2PredictTpu() throws IOException { | ||
Gemma2PredictTpu creator = new Gemma2PredictTpu(mockPredictionServiceClient); | ||
String response = creator.gemma2PredictTpu(projectId, region, endpointId, parameters); | ||
|
||
assertEquals(mockedResponse, response); | ||
TetyanaYahodska marked this conversation as resolved.
Show resolved
Hide resolved
|
||
} | ||
|
||
@Test | ||
public void testGemma2PredictGpu() throws IOException { | ||
Gemma2PredictGpu creator = new Gemma2PredictGpu(mockPredictionServiceClient); | ||
String response = creator.gemma2PredictGpu(projectId, region, endpointId, parameters); | ||
|
||
assertEquals(mockedResponse, response); | ||
} | ||
} |
109 changes: 109 additions & 0 deletions
109
aiplatform/src/test/java/aiplatform/Gemma2PredictionsTest.java
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
package aiplatform; | ||
|
||
import static org.junit.jupiter.api.Assertions.assertTrue; | ||
import static org.mockito.Mockito.*; | ||
|
||
import com.google.cloud.aiplatform.v1.PredictResponse; | ||
import com.google.cloud.aiplatform.v1.PredictionServiceClient; | ||
import com.google.cloud.aiplatform.v1.stub.PredictionServiceStub; | ||
import com.google.protobuf.Value; | ||
import com.google.protobuf.util.JsonFormat; | ||
import org.junit.jupiter.api.Test; | ||
import org.mockito.Mockito; | ||
import java.io.IOException; | ||
import java.util.ArrayList; | ||
import java.util.List; | ||
|
||
|
||
public class Gemma2PredictionsTest { | ||
|
||
private static final String RESPONSE = "The sky appears blue due to a phenomenon called **Rayleigh scattering**.\n" + | ||
"**Here's how it works:**\n" + | ||
"1. **Sunlight:** Sunlight is composed of all the colors of the rainbow.\n" + | ||
"2. **Earth's Atmosphere:** When sunlight enters the Earth's atmosphere, it collides with tiny particles like nitrogen and oxygen molecules.\n" + | ||
"3. **Scattering:** These particles scatter the sunlight in all directions. However, blue light (which has a shorter wavelength) is scattered more effectively than other colors.\n" + | ||
"4. **Our Perception:** As a result, we see a blue sky because the scattered blue light reaches our eyes from all directions.\n" + | ||
"**Why not other colors?**\n" + | ||
"* **Violet light** has an even shorter wavelength than blue and is scattered even more. However, our eyes are less sensitive to violet light, so we perceive the sky as blue.\n" + | ||
"* **Longer wavelengths** like red, orange, and yellow are scattered less and travel more directly through the atmosphere. This is why we see these colors during sunrise and sunset, when sunlight has to travel through more of the atmosphere.\n"; | ||
|
||
private static final String PROJECT_ID = "rsamborski-ai-hypercomputer"; | ||
private static final String GPU_ENDPOINT_REGION = "us-east1"; | ||
private static final String GPU_ENDPOINT_ID = "323876543124209664"; // Mock ID used to check if GPU was called | ||
private static final String TPU_ENDPOINT_REGION = "us-west1"; | ||
private static final String TPU_ENDPOINT_ID = "9194824316951199744"; | ||
private static final String PARAMETERS = | ||
"{\n" | ||
+ " \"temperature\": 0.3,\n" | ||
+ " \"maxOutputTokens\": 200,\n" | ||
+ " \"topP\": 0.8,\n" | ||
+ " \"topK\": 40\n" | ||
+ "}"; | ||
private final PredictionServiceStub mockStub = Mockito.mock(PredictionServiceStub.class); | ||
PredictionServiceClient client = PredictionServiceClient.create(mockStub); | ||
|
||
@Test | ||
public void testShouldRunInterferenceWithGPU() throws IOException { | ||
PredictionServiceStub mockStub = Mockito.mock(PredictionServiceStub.class); | ||
PredictionServiceClient client = PredictionServiceClient.create(mockStub); | ||
String instance = "{ \"inputs\": \"Why is the sky blue?\"}"; | ||
Value.Builder instanceValue = Value.newBuilder(); | ||
JsonFormat.parser().merge(instance, instanceValue); | ||
|
||
Value.Builder parameterValueBuilder = Value.newBuilder(); | ||
JsonFormat.parser().merge(PARAMETERS, parameterValueBuilder); | ||
Value parameterValue = parameterValueBuilder.build(); | ||
|
||
List<Value> instances = new ArrayList<>(); | ||
instances.add(instanceValue.build()); | ||
|
||
PredictResponse mockResponse = PredictResponse.newBuilder() | ||
.addPredictions(Value.newBuilder().setStringValue(RESPONSE).build()) | ||
.build(); | ||
|
||
Mockito.when(client.predict(GPU_ENDPOINT_ID, instances, parameterValue)).thenReturn(mockResponse); | ||
// NullPointerException | ||
Gemma2PredictGpu gemma2PredictGpu = new Gemma2PredictGpu(client); | ||
|
||
String output = gemma2PredictGpu.gemma2PredictGpu(PROJECT_ID, GPU_ENDPOINT_REGION, GPU_ENDPOINT_ID, PARAMETERS); | ||
|
||
assertTrue(output.contains("Rayleigh scattering")); | ||
verify(client, times(1)).predict(GPU_ENDPOINT_ID, instances, parameterValue); | ||
assertTrue(instances.contains("inputs")); | ||
assertTrue(parameterValue.getStructValue().containsFields("temperature")); | ||
assertTrue(parameterValue.getStructValue().containsFields("maxOutputTokens")); | ||
assertTrue(parameterValue.getStructValue().containsFields("topP")); | ||
assertTrue(parameterValue.getStructValue().containsFields("topK")); | ||
} | ||
|
||
@Test | ||
public void testShouldRunInterferenceWithTPU() throws IOException { | ||
|
||
String instance = "{ \"prompt\": \"Why is the sky blue?\"}"; | ||
Value.Builder instanceValue = Value.newBuilder(); | ||
JsonFormat.parser().merge(instance, instanceValue); | ||
|
||
Value.Builder parameterValueBuilder = Value.newBuilder(); | ||
JsonFormat.parser().merge(PARAMETERS, parameterValueBuilder); | ||
Value parameterValue = parameterValueBuilder.build(); | ||
|
||
List<Value> instances = new ArrayList<>(); | ||
instances.add(instanceValue.build()); | ||
|
||
PredictResponse mockResponse = PredictResponse.newBuilder() | ||
.addPredictions(Value.newBuilder().setStringValue(RESPONSE).build()) | ||
.build(); | ||
when(client.predict(TPU_ENDPOINT_ID, instances, parameterValue)).thenReturn(mockResponse); | ||
// NullPointerException | ||
Gemma2PredictTpu gemma2PredictTpu = new Gemma2PredictTpu(client); | ||
String output = gemma2PredictTpu.gemma2PredictTpu(PROJECT_ID,TPU_ENDPOINT_REGION, TPU_ENDPOINT_ID, PARAMETERS); | ||
|
||
assertTrue(output.contains("Rayleigh scattering")); | ||
verify(client, times(1)).predict(GPU_ENDPOINT_ID, instances, parameterValue); | ||
assertTrue(instances.contains("prompt")); | ||
assertTrue(parameterValue.getStructValue().containsFields("temperature")); | ||
assertTrue(parameterValue.getStructValue().containsFields("maxOutputTokens")); | ||
assertTrue(parameterValue.getStructValue().containsFields("topP")); | ||
assertTrue(parameterValue.getStructValue().containsFields("topK")); | ||
} | ||
} |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.