Skip to content

Commit 8957a58

Browse files
authored
feat: gemma2 samples with accelerated TPU and GPU (#4395)
1 parent fc150e3 commit 8957a58

File tree

6 files changed

+521
-0
lines changed

6 files changed

+521
-0
lines changed
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
// Copyright 2024 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// https://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package snippets
16+
17+
// [START generativeaionvertexai_gemma2_predict_gpu]
18+
import (
19+
"context"
20+
"fmt"
21+
"io"
22+
23+
"cloud.google.com/go/aiplatform/apiv1/aiplatformpb"
24+
25+
"google.golang.org/protobuf/types/known/structpb"
26+
)
27+
28+
// predictGPU demonstrates how to run interference on a Gemma2 model deployed to a Vertex AI endpoint with GPU accelerators.
29+
func predictGPU(w io.Writer, client PredictionsClient, projectID, location, endpointID string) error {
30+
ctx := context.Background()
31+
32+
// Note: client can be initialized in the following way:
33+
// apiEndpoint := fmt.Sprintf("%s-aiplatform.googleapis.com:443", location)
34+
// client, err := aiplatform.NewPredictionClient(ctx, option.WithEndpoint(apiEndpoint))
35+
// if err != nil {
36+
// return fmt.Errorf("unable to create prediction client: %v", err)
37+
// }
38+
// defer client.Close()
39+
40+
gemma2Endpoint := fmt.Sprintf("projects/%s/locations/%s/endpoints/%s", projectID, location, endpointID)
41+
prompt := "Why is the sky blue?"
42+
parameters := map[string]interface{}{
43+
"temperature": 0.9,
44+
"maxOutputTokens": 1024,
45+
"topP": 1.0,
46+
"topK": 1,
47+
}
48+
49+
// Encapsulate the prompt in a correct format for TPUs.
50+
// Pay attention that prompt should be set in "inputs" field.
51+
// Example format: [{'inputs': 'Why is the sky blue?', 'parameters': {'temperature': 0.9}}]
52+
promptValue, err := structpb.NewValue(map[string]interface{}{
53+
"inputs": prompt,
54+
"parameters": parameters,
55+
})
56+
if err != nil {
57+
fmt.Fprintf(w, "unable to convert prompt to Value: %v", err)
58+
return err
59+
}
60+
61+
req := &aiplatformpb.PredictRequest{
62+
Endpoint: gemma2Endpoint,
63+
Instances: []*structpb.Value{promptValue},
64+
}
65+
66+
resp, err := client.Predict(ctx, req)
67+
if err != nil {
68+
return err
69+
}
70+
71+
prediction := resp.GetPredictions()
72+
value := prediction[0].GetStringValue()
73+
fmt.Fprintf(w, "%v", value)
74+
75+
return nil
76+
}
77+
78+
// [END generativeaionvertexai_gemma2_predict_gpu]
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
// Copyright 2024 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// https://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package snippets
16+
17+
// [START generativeaionvertexai_gemma2_predict_tpu]
18+
import (
19+
"context"
20+
"fmt"
21+
"io"
22+
23+
"cloud.google.com/go/aiplatform/apiv1/aiplatformpb"
24+
25+
"google.golang.org/protobuf/types/known/structpb"
26+
)
27+
28+
// predictTPU demonstrates how to run interference on a Gemma2 model deployed to a Vertex AI endpoint with TPU accelerators.
29+
func predictTPU(w io.Writer, client PredictionsClient, projectID, location, endpointID string) error {
30+
ctx := context.Background()
31+
32+
// Note: client can be initialized in the following way:
33+
// apiEndpoint := fmt.Sprintf("%s-aiplatform.googleapis.com:443", location)
34+
// client, err := aiplatform.NewPredictionClient(ctx, option.WithEndpoint(apiEndpoint))
35+
// if err != nil {
36+
// return fmt.Errorf("unable to create prediction client: %v", err)
37+
// }
38+
// defer client.Close()
39+
40+
gemma2Endpoint := fmt.Sprintf("projects/%s/locations/%s/endpoints/%s", projectID, location, endpointID)
41+
prompt := "Why is the sky blue?"
42+
parameters := map[string]interface{}{
43+
"temperature": 0.9,
44+
"maxOutputTokens": 1024,
45+
"topP": 1.0,
46+
"topK": 1,
47+
}
48+
49+
// Encapsulate the prompt in a correct format for TPUs.
50+
// Example format: [{'prompt': 'Why is the sky blue?', 'temperature': 0.9}]
51+
promptValue, err := structpb.NewValue(map[string]interface{}{
52+
"prompt": prompt,
53+
"parameters": parameters,
54+
})
55+
if err != nil {
56+
fmt.Fprintf(w, "unable to convert prompt to Value: %v", err)
57+
return err
58+
}
59+
60+
req := &aiplatformpb.PredictRequest{
61+
Endpoint: gemma2Endpoint,
62+
Instances: []*structpb.Value{promptValue},
63+
}
64+
65+
resp, err := client.Predict(ctx, req)
66+
if err != nil {
67+
return err
68+
}
69+
70+
prediction := resp.GetPredictions()
71+
value := prediction[0].GetStringValue()
72+
fmt.Fprintf(w, "%v", value)
73+
74+
return nil
75+
}
76+
77+
// [END generativeaionvertexai_gemma2_predict_tpu]

vertexai/gemma2/gemma2_test.go

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
// Copyright 2024 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package snippets
16+
17+
import (
18+
"bytes"
19+
"context"
20+
"strings"
21+
"testing"
22+
23+
"cloud.google.com/go/aiplatform/apiv1/aiplatformpb"
24+
"github.com/GoogleCloudPlatform/golang-samples/internal/testutil"
25+
"github.com/googleapis/gax-go/v2"
26+
)
27+
28+
func TestPredictGemma2(t *testing.T) {
29+
tc := testutil.SystemTest(t)
30+
31+
projectID := tc.ProjectID
32+
var buf bytes.Buffer
33+
client := PredictionsClient{}
34+
35+
t.Run("GPU predict", func(t *testing.T) {
36+
buf.Reset()
37+
// Mock ID used to check if GPU was called
38+
if err := predictGPU(&buf, client, projectID, GPUEndpointRegion, GPUEndpointID); err != nil {
39+
t.Fatal(err)
40+
}
41+
42+
if got := buf.String(); !strings.Contains(got, "Rayleigh scattering") {
43+
t.Error("generated text content not found in response")
44+
}
45+
})
46+
47+
t.Run("TPU predict", func(t *testing.T) {
48+
buf.Reset()
49+
// Mock ID used to check if TPU was called
50+
if err := predictTPU(&buf, client, projectID, TPUEndpointRegion, TPUEndpointID); err != nil {
51+
t.Fatal(err)
52+
}
53+
54+
if got := buf.String(); !strings.Contains(got, "Rayleigh scattering") {
55+
t.Error("generated text content not found in response")
56+
}
57+
})
58+
}
59+
60+
type PredictClientInterface interface {
61+
Close() error
62+
Predict(ctx context.Context, req *aiplatformpb.PredictRequest, opts ...gax.CallOption) (*aiplatformpb.PredictResponse, error)
63+
}

vertexai/gemma2/go.mod

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
module github.com/GoogleCloudPlatform/golang-samples/gemma2
2+
3+
go 1.21
4+
5+
require (
6+
cloud.google.com/go/aiplatform v1.68.0
7+
github.com/GoogleCloudPlatform/golang-samples v0.0.0-20240918200157-a00ca430a14b
8+
github.com/googleapis/gax-go/v2 v2.13.0
9+
google.golang.org/protobuf v1.34.2
10+
)
11+
12+
require (
13+
cloud.google.com/go v0.115.1 // indirect
14+
cloud.google.com/go/auth v0.9.3 // indirect
15+
cloud.google.com/go/auth/oauth2adapt v0.2.4 // indirect
16+
cloud.google.com/go/compute/metadata v0.5.0 // indirect
17+
cloud.google.com/go/iam v1.2.0 // indirect
18+
cloud.google.com/go/longrunning v0.6.0 // indirect
19+
cloud.google.com/go/storage v1.43.0 // indirect
20+
github.com/felixge/httpsnoop v1.0.4 // indirect
21+
github.com/go-logr/logr v1.4.2 // indirect
22+
github.com/go-logr/stdr v1.2.2 // indirect
23+
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
24+
github.com/google/s2a-go v0.1.8 // indirect
25+
github.com/google/uuid v1.6.0 // indirect
26+
github.com/googleapis/enterprise-certificate-proxy v0.3.4 // indirect
27+
go.opencensus.io v0.24.0 // indirect
28+
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.54.0 // indirect
29+
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.54.0 // indirect
30+
go.opentelemetry.io/otel v1.29.0 // indirect
31+
go.opentelemetry.io/otel/metric v1.29.0 // indirect
32+
go.opentelemetry.io/otel/trace v1.29.0 // indirect
33+
golang.org/x/crypto v0.27.0 // indirect
34+
golang.org/x/net v0.29.0 // indirect
35+
golang.org/x/oauth2 v0.23.0 // indirect
36+
golang.org/x/sync v0.8.0 // indirect
37+
golang.org/x/sys v0.25.0 // indirect
38+
golang.org/x/text v0.18.0 // indirect
39+
golang.org/x/time v0.6.0 // indirect
40+
google.golang.org/api v0.197.0 // indirect
41+
google.golang.org/genproto v0.0.0-20240903143218-8af14fe29dc1 // indirect
42+
google.golang.org/genproto/googleapis/api v0.0.0-20240827150818-7e3bb234dfed // indirect
43+
google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1 // indirect
44+
google.golang.org/grpc v1.66.1 // indirect
45+
)

0 commit comments

Comments
 (0)