Skip to content

Commit 6c42138

Browse files
committed
correcting according to comments
1 parent 216227f commit 6c42138

File tree

4 files changed

+50
-17
lines changed

4 files changed

+50
-17
lines changed

gemma2/gemma2_predict_gpu.go

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,6 @@ import (
2626
"google.golang.org/protobuf/types/known/structpb"
2727
)
2828

29-
type ClientInterface interface {
30-
Close() error
31-
Predict(ctx context.Context, req *aiplatformpb.PredictRequest, opts ...gax.CallOption) (*aiplatformpb.PredictResponse, error)
32-
}
33-
3429
// predictGPU demonstrates how to run interference on a Gemma2 model deployed to a Vertex AI endpoint with GPU accelerators.
3530
func predictGPU(w io.Writer, client ClientInterface, projectID, location, endpointID string) error {
3631
ctx := context.Background()
@@ -54,6 +49,7 @@ func predictGPU(w io.Writer, client ClientInterface, projectID, location, endpoi
5449

5550
// Encapsulate the prompt in a correct format for TPUs.
5651
// Pay attention that prompt should be set in "inputs" field.
52+
// Example format: [{'inputs': 'Why is the sky blue?', 'parameters': {'temperature': 0.9}}]
5753
promptValue, err := structpb.NewValue(map[string]interface{}{
5854
"inputs": prompt,
5955
"parameters": parameters,
@@ -81,3 +77,8 @@ func predictGPU(w io.Writer, client ClientInterface, projectID, location, endpoi
8177
}
8278

8379
// [END generativeaionvertexai_gemma2_predict_gpu]
80+
81+
type ClientInterface interface {
82+
Close() error
83+
Predict(ctx context.Context, req *aiplatformpb.PredictRequest, opts ...gax.CallOption) (*aiplatformpb.PredictResponse, error)
84+
}

gemma2/gemma2_predict_tpu.go

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,6 @@ import (
2626
"google.golang.org/protobuf/types/known/structpb"
2727
)
2828

29-
type PredictClientInterface interface {
30-
Close() error
31-
Predict(ctx context.Context, req *aiplatformpb.PredictRequest, opts ...gax.CallOption) (*aiplatformpb.PredictResponse, error)
32-
}
33-
3429
// predictTPU demonstrates how to run interference on a Gemma2 model deployed to a Vertex AI endpoint with TPU accelerators.
3530
func predictTPU(w io.Writer, client PredictClientInterface, projectID, location, endpointID string) error {
3631
ctx := context.Background()
@@ -53,6 +48,7 @@ func predictTPU(w io.Writer, client PredictClientInterface, projectID, location,
5348
}
5449

5550
// Encapsulate the prompt in a correct format for TPUs.
51+
// Example format: [{'prompt': 'Why is the sky blue?', 'temperature': 0.9}]
5652
promptValue, err := structpb.NewValue(map[string]interface{}{
5753
"prompt": prompt,
5854
"parameters": parameters,
@@ -80,3 +76,8 @@ func predictTPU(w io.Writer, client PredictClientInterface, projectID, location,
8076
}
8177

8278
// [END generativeaionvertexai_gemma2_predict_tpu]
79+
80+
type PredictClientInterface interface {
81+
Close() error
82+
Predict(ctx context.Context, req *aiplatformpb.PredictRequest, opts ...gax.CallOption) (*aiplatformpb.PredictResponse, error)
83+
}

gemma2/gemma2_test.go

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,14 @@ import (
2222
"github.com/GoogleCloudPlatform/golang-samples/internal/testutil"
2323
)
2424

25-
func TestPredictGPU(t *testing.T) {
25+
const (
26+
GPUEndpointID string = "123456789"
27+
GPUEndpointRegion = "us-east1"
28+
TPUEndpointID = "987654321"
29+
TPUEndpointRegion = "us-west1"
30+
)
31+
32+
func TestPredictGemma2(t *testing.T) {
2633
tc := testutil.SystemTest(t)
2734

2835
projectID := tc.ProjectID
@@ -32,9 +39,7 @@ func TestPredictGPU(t *testing.T) {
3239
t.Run("GPU predict", func(t *testing.T) {
3340
buf.Reset()
3441
// Mock ID used to check if GPU was called
35-
endpointID := "123456789"
36-
location := "us-east4"
37-
if err := predictGPU(&buf, client, projectID, location, endpointID); err != nil {
42+
if err := predictGPU(&buf, client, projectID, GPUEndpointRegion, GPUEndpointID); err != nil {
3843
t.Fatal(err)
3944
}
4045

@@ -46,9 +51,7 @@ func TestPredictGPU(t *testing.T) {
4651
t.Run("TPU predict", func(t *testing.T) {
4752
buf.Reset()
4853
// Mock ID used to check if TPU was called
49-
endpointID := "123456789"
50-
location := "us-west1"
51-
if err := predictTPU(&buf, client, projectID, location, endpointID); err != nil {
54+
if err := predictTPU(&buf, client, projectID, TPUEndpointRegion, TPUEndpointID); err != nil {
5255
t.Fatal(err)
5356
}
5457

gemma2/mock_client.go

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ package snippets
1616

1717
import (
1818
"context"
19+
"fmt"
20+
"strings"
1921

2022
"cloud.google.com/go/aiplatform/apiv1/aiplatformpb"
2123
gax "github.com/googleapis/gax-go/v2"
@@ -43,5 +45,31 @@ func (client PredictionsClient) Predict(ctx context.Context, req *aiplatformpb.P
4345
response := &aiplatformpb.PredictResponse{
4446
Predictions: []*structpb.Value{structpb.NewStringValue(mockedResponse)},
4547
}
48+
49+
instance := req.Instances[0].GetStructValue()
50+
if ok := strings.Contains(req.Endpoint, fmt.Sprintf("locations/%s/endpoints/%s", GPUEndpointRegion, GPUEndpointID)); ok {
51+
if err := instance.Fields["inputs"].GetStringValue(); err == "" {
52+
return nil, fmt.Errorf("invalid request")
53+
}
54+
} else if ok := strings.Contains(req.Endpoint, fmt.Sprintf("locations/%s/endpoints/%s", TPUEndpointRegion, TPUEndpointID)); ok {
55+
if err := instance.Fields["prompt"].GetStringValue(); err == "" {
56+
return nil, fmt.Errorf("invalid request")
57+
}
58+
}
59+
60+
params := req.Instances[0].GetStructValue().Fields["parameters"].GetStructValue()
61+
62+
if err := params.Fields["temperature"].GetNumberValue(); err == 0 {
63+
return nil, fmt.Errorf("invalid request")
64+
}
65+
if err := params.Fields["maxOutputTokens"].GetNumberValue(); err == 0 {
66+
return nil, fmt.Errorf("invalid request")
67+
}
68+
if err := params.Fields["topP"].GetNumberValue(); err == 0 {
69+
return nil, fmt.Errorf("invalid request")
70+
}
71+
if err := params.Fields["topK"].GetNumberValue(); err == 0 {
72+
return nil, fmt.Errorf("invalid request")
73+
}
4674
return response, nil
4775
}

0 commit comments

Comments
 (0)