Skip to content

Commit 7e86d0e

Browse files
committed
feat: gemma2 samples with accelerated TPU and GPU
1 parent a00ca43 commit 7e86d0e

File tree

7 files changed

+493
-0
lines changed

7 files changed

+493
-0
lines changed

.github/CODEOWNERS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@
6060
/auth/ @GoogleCloudPlatform/go-samples-reviewers @GoogleCloudPlatform/cloud-samples-reviewers @GoogleCloudPlatform/dee-infra @GoogleCloudPlatform/googleapis-auth
6161
/batch/ @GoogleCloudPlatform/go-samples-reviewers @GoogleCloudPlatform/cloud-samples-reviewers @GoogleCloudPlatform/dee-infra
6262
/compute/ @GoogleCloudPlatform/go-samples-reviewers @GoogleCloudPlatform/cloud-samples-reviewers @GoogleCloudPlatform/dee-infra
63+
/gemma2 @GoogleCloudPlatform/go-samples-reviewers @GoogleCloudPlatform/cloud-samples-reviewers @GoogleCloudPlatform/dee-infra
6364
/iam/ @GoogleCloudPlatform/go-samples-reviewers @GoogleCloudPlatform/cloud-samples-reviewers @GoogleCloudPlatform/dee-infra
6465
/iap/ @GoogleCloudPlatform/go-samples-reviewers @GoogleCloudPlatform/cloud-samples-reviewers @GoogleCloudPlatform/dee-infra
6566
/kms/ @GoogleCloudPlatform/go-samples-reviewers @GoogleCloudPlatform/cloud-samples-reviewers @GoogleCloudPlatform/dee-infra

gemma2/gemma2_predict_gpu.go

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
// Copyright 2024 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// https://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package snippets
16+
17+
// [START generativeaionvertexai_gemma2_predict_gpu]
18+
import (
19+
"context"
20+
"fmt"
21+
"io"
22+
23+
"cloud.google.com/go/aiplatform/apiv1/aiplatformpb"
24+
gax "github.com/googleapis/gax-go/v2"
25+
26+
"google.golang.org/protobuf/types/known/structpb"
27+
)
28+
29+
type ClientInterface interface {
30+
Close() error
31+
Predict(ctx context.Context, req *aiplatformpb.PredictRequest, opts ...gax.CallOption) (*aiplatformpb.PredictResponse, error)
32+
}
33+
34+
// predictGPU demopnstrates how to run interference on a Gemma2 model deployed to a Vertex AI endpoint with GPU accellerators.
35+
func predictGPU(w io.Writer, client ClientInterface, projectID, location, endpointID string) error {
36+
ctx := context.Background()
37+
38+
// Note: client can be initialised in the following way:
39+
// apiEndpoint := fmt.Sprintf("%s-aiplatform.googleapis.com:443", location)
40+
// client, err := aiplatform.NewPredictionClient(ctx, option.WithEndpoint(apiEndpoint))
41+
// if err != nil {
42+
// return fmt.Errorf("unable to create prediction client: %v", err)
43+
// }
44+
// defer client.Close()
45+
46+
gemma2Endpoint := fmt.Sprintf("projects/%s/locations/%s/endpoints/%s", projectID, location, endpointID)
47+
prompt := "Why is the sky blue?"
48+
parameters := map[string]interface{}{
49+
"temperature": 0.9,
50+
"maxOutputTokens": 1024,
51+
"topP": 1.0,
52+
"topK": 1,
53+
}
54+
55+
// Encapsulate the prompt in a correct format for TPUs.
56+
// Pay attention that prompt should be set in "inputs" field.
57+
promptValue, err := structpb.NewValue(map[string]interface{}{
58+
"inputs": prompt,
59+
"parameters": parameters,
60+
})
61+
if err != nil {
62+
fmt.Fprintf(w, "unable to convert prompt to Value: %v", err)
63+
return err
64+
}
65+
66+
req := &aiplatformpb.PredictRequest{
67+
Endpoint: gemma2Endpoint,
68+
Instances: []*structpb.Value{promptValue},
69+
}
70+
71+
resp, err := client.Predict(ctx, req)
72+
if err != nil {
73+
return err
74+
}
75+
76+
prediction := resp.GetPredictions()
77+
value := prediction[0].GetStringValue()
78+
fmt.Fprintf(w, "%v", value)
79+
80+
return nil
81+
}
82+
83+
// [END generativeaionvertexai_gemma2_predict_gpu]

gemma2/gemma2_predict_tpu.go

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
// Copyright 2024 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// https://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package snippets
16+
17+
// [START generativeaionvertexai_gemma2_predict_tpu]
18+
import (
19+
"context"
20+
"fmt"
21+
"io"
22+
23+
"cloud.google.com/go/aiplatform/apiv1/aiplatformpb"
24+
gax "github.com/googleapis/gax-go/v2"
25+
26+
"google.golang.org/protobuf/types/known/structpb"
27+
)
28+
29+
type PredictClientInterface interface {
30+
Close() error
31+
Predict(ctx context.Context, req *aiplatformpb.PredictRequest, opts ...gax.CallOption) (*aiplatformpb.PredictResponse, error)
32+
}
33+
34+
// predictTPU demopnstrates how to run interference on a Gemma2 model deployed to a Vertex AI endpoint with TPU accellerators.
35+
func predictTPU(w io.Writer, client PredictClientInterface, projectID, location, endpointID string) error {
36+
ctx := context.Background()
37+
38+
// Note: client can be initialised in the following way:
39+
// apiEndpoint := fmt.Sprintf("%s-aiplatform.googleapis.com:443", location)
40+
// client, err := aiplatform.NewPredictionClient(ctx, option.WithEndpoint(apiEndpoint))
41+
// if err != nil {
42+
// return fmt.Errorf("unable to create prediction client: %v", err)
43+
// }
44+
// defer client.Close()
45+
46+
gemma2Endpoint := fmt.Sprintf("projects/%s/locations/%s/endpoints/%s", projectID, location, endpointID)
47+
prompt := "Why is the sky blue?"
48+
parameters := map[string]interface{}{
49+
"temperature": 0.9,
50+
"maxOutputTokens": 1024,
51+
"topP": 1.0,
52+
"topK": 1,
53+
}
54+
55+
// Encapsulate the prompt in a correct format for TPUs.
56+
promptValue, err := structpb.NewValue(map[string]interface{}{
57+
"prompt": prompt,
58+
"parameters": parameters,
59+
})
60+
if err != nil {
61+
fmt.Fprintf(w, "unable to convert prompt to Value: %v", err)
62+
return err
63+
}
64+
65+
req := &aiplatformpb.PredictRequest{
66+
Endpoint: gemma2Endpoint,
67+
Instances: []*structpb.Value{promptValue},
68+
}
69+
70+
resp, err := client.Predict(ctx, req)
71+
if err != nil {
72+
return err
73+
}
74+
75+
prediction := resp.GetPredictions()
76+
value := prediction[0].GetStringValue()
77+
fmt.Fprintf(w, "%v", value)
78+
79+
return nil
80+
}
81+
82+
// [END generativeaionvertexai_gemma2_predict_tpu]

gemma2/gemma2_test.go

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
// Copyright 2024 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package snippets
16+
17+
import (
18+
"bytes"
19+
"strings"
20+
"testing"
21+
22+
"github.com/GoogleCloudPlatform/golang-samples/internal/testutil"
23+
)
24+
25+
func TestPredictGPU(t *testing.T) {
26+
tc := testutil.SystemTest(t)
27+
28+
projectID := tc.ProjectID
29+
var buf bytes.Buffer
30+
client := PredictionsClient{}
31+
32+
t.Run("GPU predict", func(t *testing.T) {
33+
buf.Reset()
34+
// Mock ID used to check if GPU was called
35+
endpointID := "123456789"
36+
location := "us-east4"
37+
if err := predictGPU(&buf, client, projectID, location, endpointID); err != nil {
38+
t.Fatal(err)
39+
}
40+
41+
if got := buf.String(); !strings.Contains(got, "Rayleigh scattering") {
42+
t.Error("generated text content not found in response")
43+
}
44+
})
45+
46+
t.Run("TPU predict", func(t *testing.T) {
47+
buf.Reset()
48+
// Mock ID used to check if TPU was called
49+
endpointID := "123456789"
50+
location := "us-west1"
51+
if err := predictTPU(&buf, client, projectID, location, endpointID); err != nil {
52+
t.Fatal(err)
53+
}
54+
55+
if got := buf.String(); !strings.Contains(got, "Rayleigh scattering") {
56+
t.Error("generated text content not found in response")
57+
}
58+
})
59+
}

gemma2/go.mod

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
module github.com/GoogleCloudPlatform/golang-samples/gemma2
2+
3+
go 1.21
4+
5+
require (
6+
cloud.google.com/go/aiplatform v1.68.0
7+
github.com/GoogleCloudPlatform/golang-samples v0.0.0-20240918200157-a00ca430a14b
8+
github.com/googleapis/gax-go/v2 v2.13.0
9+
google.golang.org/protobuf v1.34.2
10+
)
11+
12+
require (
13+
cloud.google.com/go v0.115.1 // indirect
14+
cloud.google.com/go/auth v0.9.3 // indirect
15+
cloud.google.com/go/auth/oauth2adapt v0.2.4 // indirect
16+
cloud.google.com/go/compute/metadata v0.5.0 // indirect
17+
cloud.google.com/go/iam v1.2.0 // indirect
18+
cloud.google.com/go/longrunning v0.6.0 // indirect
19+
cloud.google.com/go/storage v1.43.0 // indirect
20+
github.com/felixge/httpsnoop v1.0.4 // indirect
21+
github.com/go-logr/logr v1.4.2 // indirect
22+
github.com/go-logr/stdr v1.2.2 // indirect
23+
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
24+
github.com/google/s2a-go v0.1.8 // indirect
25+
github.com/google/uuid v1.6.0 // indirect
26+
github.com/googleapis/enterprise-certificate-proxy v0.3.4 // indirect
27+
go.opencensus.io v0.24.0 // indirect
28+
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.54.0 // indirect
29+
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.54.0 // indirect
30+
go.opentelemetry.io/otel v1.29.0 // indirect
31+
go.opentelemetry.io/otel/metric v1.29.0 // indirect
32+
go.opentelemetry.io/otel/trace v1.29.0 // indirect
33+
golang.org/x/crypto v0.27.0 // indirect
34+
golang.org/x/net v0.29.0 // indirect
35+
golang.org/x/oauth2 v0.23.0 // indirect
36+
golang.org/x/sync v0.8.0 // indirect
37+
golang.org/x/sys v0.25.0 // indirect
38+
golang.org/x/text v0.18.0 // indirect
39+
golang.org/x/time v0.6.0 // indirect
40+
google.golang.org/api v0.197.0 // indirect
41+
google.golang.org/genproto v0.0.0-20240903143218-8af14fe29dc1 // indirect
42+
google.golang.org/genproto/googleapis/api v0.0.0-20240827150818-7e3bb234dfed // indirect
43+
google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1 // indirect
44+
google.golang.org/grpc v1.66.1 // indirect
45+
)

0 commit comments

Comments
 (0)