Skip to content

Commit 9f7d7cd

Browse files
committed
correcting according to comments
1 parent 216227f commit 9f7d7cd

File tree

4 files changed

+50
-17
lines changed

4 files changed

+50
-17
lines changed

gemma2/gemma2_predict_gpu.go

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,6 @@ import (
2626
"google.golang.org/protobuf/types/known/structpb"
2727
)
2828

29-
type ClientInterface interface {
30-
Close() error
31-
Predict(ctx context.Context, req *aiplatformpb.PredictRequest, opts ...gax.CallOption) (*aiplatformpb.PredictResponse, error)
32-
}
33-
3429
// predictGPU demonstrates how to run interference on a Gemma2 model deployed to a Vertex AI endpoint with GPU accelerators.
3530
func predictGPU(w io.Writer, client ClientInterface, projectID, location, endpointID string) error {
3631
ctx := context.Background()
@@ -54,6 +49,7 @@ func predictGPU(w io.Writer, client ClientInterface, projectID, location, endpoi
5449

5550
// Encapsulate the prompt in a correct format for TPUs.
5651
// Pay attention that prompt should be set in "inputs" field.
52+
// Example format: [{'inputs': 'Why is the sky blue?', 'parameters': {'temperature': 0.9}}]
5753
promptValue, err := structpb.NewValue(map[string]interface{}{
5854
"inputs": prompt,
5955
"parameters": parameters,
@@ -81,3 +77,8 @@ func predictGPU(w io.Writer, client ClientInterface, projectID, location, endpoi
8177
}
8278

8379
// [END generativeaionvertexai_gemma2_predict_gpu]
80+
81+
type ClientInterface interface {
82+
Close() error
83+
Predict(ctx context.Context, req *aiplatformpb.PredictRequest, opts ...gax.CallOption) (*aiplatformpb.PredictResponse, error)
84+
}

gemma2/gemma2_predict_tpu.go

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,6 @@ import (
2626
"google.golang.org/protobuf/types/known/structpb"
2727
)
2828

29-
type PredictClientInterface interface {
30-
Close() error
31-
Predict(ctx context.Context, req *aiplatformpb.PredictRequest, opts ...gax.CallOption) (*aiplatformpb.PredictResponse, error)
32-
}
33-
3429
// predictTPU demonstrates how to run interference on a Gemma2 model deployed to a Vertex AI endpoint with TPU accelerators.
3530
func predictTPU(w io.Writer, client PredictClientInterface, projectID, location, endpointID string) error {
3631
ctx := context.Background()
@@ -53,6 +48,7 @@ func predictTPU(w io.Writer, client PredictClientInterface, projectID, location,
5348
}
5449

5550
// Encapsulate the prompt in a correct format for TPUs.
51+
// Example format: [{'prompt': 'Why is the sky blue?', 'temperature': 0.9}]
5652
promptValue, err := structpb.NewValue(map[string]interface{}{
5753
"prompt": prompt,
5854
"parameters": parameters,
@@ -80,3 +76,8 @@ func predictTPU(w io.Writer, client PredictClientInterface, projectID, location,
8076
}
8177

8278
// [END generativeaionvertexai_gemma2_predict_tpu]
79+
80+
type PredictClientInterface interface {
81+
Close() error
82+
Predict(ctx context.Context, req *aiplatformpb.PredictRequest, opts ...gax.CallOption) (*aiplatformpb.PredictResponse, error)
83+
}

gemma2/gemma2_test.go

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ import (
2222
"github.com/GoogleCloudPlatform/golang-samples/internal/testutil"
2323
)
2424

25-
func TestPredictGPU(t *testing.T) {
25+
func TestPredictGemma2(t *testing.T) {
2626
tc := testutil.SystemTest(t)
2727

2828
projectID := tc.ProjectID
@@ -32,9 +32,7 @@ func TestPredictGPU(t *testing.T) {
3232
t.Run("GPU predict", func(t *testing.T) {
3333
buf.Reset()
3434
// Mock ID used to check if GPU was called
35-
endpointID := "123456789"
36-
location := "us-east4"
37-
if err := predictGPU(&buf, client, projectID, location, endpointID); err != nil {
35+
if err := predictGPU(&buf, client, projectID, GPUEndpointRegion, GPUEndpointID); err != nil {
3836
t.Fatal(err)
3937
}
4038

@@ -46,9 +44,7 @@ func TestPredictGPU(t *testing.T) {
4644
t.Run("TPU predict", func(t *testing.T) {
4745
buf.Reset()
4846
// Mock ID used to check if TPU was called
49-
endpointID := "123456789"
50-
location := "us-west1"
51-
if err := predictTPU(&buf, client, projectID, location, endpointID); err != nil {
47+
if err := predictTPU(&buf, client, projectID, TPUEndpointRegion, TPUEndpointID); err != nil {
5248
t.Fatal(err)
5349
}
5450

gemma2/mock_client.go

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,21 @@ package snippets
1616

1717
import (
1818
"context"
19+
"fmt"
20+
"strings"
1921

2022
"cloud.google.com/go/aiplatform/apiv1/aiplatformpb"
2123
gax "github.com/googleapis/gax-go/v2"
2224
"google.golang.org/protobuf/types/known/structpb"
2325
)
2426

27+
const (
28+
GPUEndpointID = "123456789"
29+
GPUEndpointRegion = "us-east1"
30+
TPUEndpointID = "987654321"
31+
TPUEndpointRegion = "us-west1"
32+
)
33+
2534
type PredictionsClient struct{}
2635

2736
func (client PredictionsClient) Close() error {
@@ -43,5 +52,31 @@ func (client PredictionsClient) Predict(ctx context.Context, req *aiplatformpb.P
4352
response := &aiplatformpb.PredictResponse{
4453
Predictions: []*structpb.Value{structpb.NewStringValue(mockedResponse)},
4554
}
55+
56+
instance := req.Instances[0].GetStructValue()
57+
if ok := strings.Contains(req.Endpoint, fmt.Sprintf("locations/%s/endpoints/%s", GPUEndpointRegion, GPUEndpointID)); ok {
58+
if err := instance.Fields["inputs"].GetStringValue(); err == "" {
59+
return nil, fmt.Errorf("invalid request")
60+
}
61+
} else if ok := strings.Contains(req.Endpoint, fmt.Sprintf("locations/%s/endpoints/%s", TPUEndpointRegion, TPUEndpointID)); ok {
62+
if err := instance.Fields["prompt"].GetStringValue(); err == "" {
63+
return nil, fmt.Errorf("invalid request")
64+
}
65+
}
66+
67+
params := req.Instances[0].GetStructValue().Fields["parameters"].GetStructValue()
68+
69+
if err := params.Fields["temperature"].GetNumberValue(); err == 0 {
70+
return nil, fmt.Errorf("invalid request")
71+
}
72+
if err := params.Fields["maxOutputTokens"].GetNumberValue(); err == 0 {
73+
return nil, fmt.Errorf("invalid request")
74+
}
75+
if err := params.Fields["topP"].GetNumberValue(); err == 0 {
76+
return nil, fmt.Errorf("invalid request")
77+
}
78+
if err := params.Fields["topK"].GetNumberValue(); err == 0 {
79+
return nil, fmt.Errorf("invalid request")
80+
}
4681
return response, nil
4782
}

0 commit comments

Comments
 (0)