Skip to content

Commit c4902dd

Browse files
authored
feat: batch predict code and text generation (#4417)
* feat: batch predict code and text generation * correct go version
1 parent 7c4f707 commit c4902dd

File tree

5 files changed

+474
-0
lines changed

5 files changed

+474
-0
lines changed
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
// Copyright 2024 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package batchpredict
16+
17+
// [START generativeaionvertexai_batch_code_predict]
18+
import (
19+
"context"
20+
"fmt"
21+
"io"
22+
23+
aiplatform "cloud.google.com/go/aiplatform/apiv1"
24+
aiplatformpb "cloud.google.com/go/aiplatform/apiv1/aiplatformpb"
25+
"google.golang.org/api/option"
26+
"google.golang.org/protobuf/types/known/structpb"
27+
)
28+
29+
// batchCodePredict perform batch code prediction using a pre-trained code generation model
30+
func batchCodePredict(w io.Writer, projectID, location, name, outputURI string, inputURIs []string) error {
31+
// inputURI := []string{"gs://cloud-samples-data/batch/prompt_for_batch_code_predict.jsonl"}
32+
// outputURI: existing template path. Following formats are allowed:
33+
// - gs://BUCKET_NAME/DIRECTORY/
34+
// - bq://project_name.llm_dataset
35+
36+
ctx := context.Background()
37+
apiEndpoint := fmt.Sprintf("%s-aiplatform.googleapis.com:443", location)
38+
// Pretrained code model
39+
model := "publishers/google/models/code-bison"
40+
parameters := map[string]interface{}{
41+
"temperature": 0.2,
42+
"maxOutputTokens": 200,
43+
}
44+
parametersValue, err := structpb.NewValue(parameters)
45+
if err != nil {
46+
fmt.Fprintf(w, "unable to convert parameters to Value: %v", err)
47+
return err
48+
}
49+
50+
client, err := aiplatform.NewJobClient(ctx, option.WithEndpoint(apiEndpoint))
51+
if err != nil {
52+
return err
53+
}
54+
55+
req := &aiplatformpb.CreateBatchPredictionJobRequest{
56+
Parent: fmt.Sprintf("projects/%s/locations/%s", projectID, location),
57+
BatchPredictionJob: &aiplatformpb.BatchPredictionJob{
58+
DisplayName: name,
59+
Model: model,
60+
ModelParameters: parametersValue,
61+
InputConfig: &aiplatformpb.BatchPredictionJob_InputConfig{
62+
Source: &aiplatformpb.BatchPredictionJob_InputConfig_GcsSource{
63+
GcsSource: &aiplatformpb.GcsSource{
64+
Uris: inputURIs,
65+
},
66+
},
67+
// List of supported formarts: https://cloud.google.com/vertex-ai/docs/reference/rpc/google.cloud.aiplatform.v1#model
68+
InstancesFormat: "jsonl",
69+
},
70+
OutputConfig: &aiplatformpb.BatchPredictionJob_OutputConfig{
71+
Destination: &aiplatformpb.BatchPredictionJob_OutputConfig_GcsDestination{
72+
GcsDestination: &aiplatformpb.GcsDestination{
73+
OutputUriPrefix: outputURI,
74+
},
75+
},
76+
// List of supported formarts: https://cloud.google.com/vertex-ai/docs/reference/rpc/google.cloud.aiplatform.v1#model
77+
PredictionsFormat: "jsonl",
78+
},
79+
},
80+
}
81+
82+
job, err := client.CreateBatchPredictionJob(ctx, req)
83+
if err != nil {
84+
return err
85+
}
86+
fmt.Fprint(w, job.GetDisplayName())
87+
88+
return nil
89+
}
90+
91+
// [END generativeaionvertexai_batch_code_predict]
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
// Copyright 2024 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package batchpredict
16+
17+
import (
18+
"bytes"
19+
"context"
20+
"fmt"
21+
"math/rand"
22+
"testing"
23+
"time"
24+
25+
"github.com/GoogleCloudPlatform/golang-samples/internal/testutil"
26+
)
27+
28+
func TestBatchPredict(t *testing.T) {
29+
tc := testutil.SystemTest(t)
30+
var buf bytes.Buffer
31+
var r *rand.Rand = rand.New(
32+
rand.NewSource(time.Now().UnixNano()))
33+
34+
ctx := context.Background()
35+
bucketName := testutil.TestBucket(ctx, t, tc.ProjectID, "golang-samples-predict")
36+
location := "us-central1"
37+
outputURI := fmt.Sprintf("gs://%s/", bucketName)
38+
39+
t.Run("code predict", func(t *testing.T) {
40+
buf.Reset()
41+
name := fmt.Sprintf("test-job-go-batch-%v-%v", time.Now().Format("2006-01-02"), r.Int())
42+
inputURIs := []string{"gs://cloud-samples-data/batch/prompt_for_batch_code_predict.jsonl"}
43+
err := batchCodePredict(&buf, tc.ProjectID, location, name, outputURI, inputURIs)
44+
if err != nil {
45+
t.Error(err)
46+
}
47+
48+
output := buf.String()
49+
if output != name {
50+
t.Errorf("job name doesn't match. Got: %s, want: %s", output, name)
51+
}
52+
})
53+
54+
t.Run("text predict", func(t *testing.T) {
55+
buf.Reset()
56+
name := fmt.Sprintf("test-job-go-batch-%v-%v", time.Now().Format("2006-01-02"), r.Int())
57+
inputURIs := []string{"gs://cloud-samples-data/batch/prompt_for_batch_text_predict.jsonl"}
58+
err := batchTextPredict(&buf, tc.ProjectID, location, name, outputURI, inputURIs)
59+
if err != nil {
60+
t.Error(err)
61+
}
62+
63+
output := buf.String()
64+
if output != name {
65+
t.Errorf("job name doesn't match. Got: %s, want: %s", output, name)
66+
}
67+
})
68+
}
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
// Copyright 2024 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package batchpredict
16+
17+
// [START generativeaionvertexai_batch_text_predict]
18+
import (
19+
"context"
20+
"fmt"
21+
"io"
22+
23+
aiplatform "cloud.google.com/go/aiplatform/apiv1"
24+
aiplatformpb "cloud.google.com/go/aiplatform/apiv1/aiplatformpb"
25+
"google.golang.org/api/option"
26+
"google.golang.org/protobuf/types/known/structpb"
27+
)
28+
29+
// batchTextPredict perform batch text prediction using a pre-trained text generation model
30+
func batchTextPredict(w io.Writer, projectID, location, name, outputURI string, inputURIs []string) error {
31+
// inputURI := []string{"gs://cloud-samples-data/batch/prompt_for_batch_text_predict.jsonl"}
32+
// outputURI: existing template path. Following formats are allowed:
33+
// - gs://BUCKET_NAME/DIRECTORY/
34+
// - bq://project_name.llm_dataset
35+
36+
ctx := context.Background()
37+
apiEndpoint := fmt.Sprintf("%s-aiplatform.googleapis.com:443", location)
38+
// Pretrained text model
39+
model := "publishers/google/models/text-bison"
40+
parameters := map[string]interface{}{
41+
"temperature": 0.2,
42+
"maxOutputTokens": 200,
43+
"topP": 0.95,
44+
"topK": 40,
45+
}
46+
47+
parametersValue, err := structpb.NewValue(parameters)
48+
if err != nil {
49+
fmt.Fprintf(w, "unable to convert parameters to Value: %v", err)
50+
return err
51+
}
52+
53+
client, err := aiplatform.NewJobClient(ctx, option.WithEndpoint(apiEndpoint))
54+
if err != nil {
55+
return err
56+
}
57+
58+
req := &aiplatformpb.CreateBatchPredictionJobRequest{
59+
Parent: fmt.Sprintf("projects/%s/locations/%s", projectID, location),
60+
BatchPredictionJob: &aiplatformpb.BatchPredictionJob{
61+
DisplayName: name,
62+
Model: model,
63+
ModelParameters: parametersValue,
64+
InputConfig: &aiplatformpb.BatchPredictionJob_InputConfig{
65+
Source: &aiplatformpb.BatchPredictionJob_InputConfig_GcsSource{
66+
GcsSource: &aiplatformpb.GcsSource{
67+
Uris: inputURIs,
68+
},
69+
},
70+
// List of supported formarts: https://cloud.google.com/vertex-ai/docs/reference/rpc/google.cloud.aiplatform.v1#model
71+
InstancesFormat: "jsonl",
72+
},
73+
OutputConfig: &aiplatformpb.BatchPredictionJob_OutputConfig{
74+
Destination: &aiplatformpb.BatchPredictionJob_OutputConfig_GcsDestination{
75+
GcsDestination: &aiplatformpb.GcsDestination{
76+
OutputUriPrefix: outputURI,
77+
},
78+
},
79+
// List of supported formarts: https://cloud.google.com/vertex-ai/docs/reference/rpc/google.cloud.aiplatform.v1#model
80+
PredictionsFormat: "jsonl",
81+
},
82+
},
83+
}
84+
85+
job, err := client.CreateBatchPredictionJob(ctx, req)
86+
if err != nil {
87+
return err
88+
}
89+
fmt.Fprint(w, job.GetDisplayName())
90+
91+
return nil
92+
}
93+
94+
// [END generativeaionvertexai_batch_text_predict]

vertexai/batch-predict/go.mod

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
module github.com/GoogleCloudPlatform/golang-samples/vertexai/batch-predict
2+
3+
go 1.21
4+
5+
require (
6+
cloud.google.com/go/aiplatform v1.68.0
7+
github.com/GoogleCloudPlatform/golang-samples v0.0.0-20241001164912-66760d064c5e
8+
google.golang.org/api v0.199.0
9+
google.golang.org/protobuf v1.34.2
10+
)
11+
12+
require (
13+
cloud.google.com/go v0.115.1 // indirect
14+
cloud.google.com/go/auth v0.9.5 // indirect
15+
cloud.google.com/go/auth/oauth2adapt v0.2.4 // indirect
16+
cloud.google.com/go/compute/metadata v0.5.2 // indirect
17+
cloud.google.com/go/iam v1.2.0 // indirect
18+
cloud.google.com/go/longrunning v0.6.0 // indirect
19+
cloud.google.com/go/storage v1.43.0 // indirect
20+
github.com/felixge/httpsnoop v1.0.4 // indirect
21+
github.com/go-logr/logr v1.4.2 // indirect
22+
github.com/go-logr/stdr v1.2.2 // indirect
23+
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
24+
github.com/google/s2a-go v0.1.8 // indirect
25+
github.com/google/uuid v1.6.0 // indirect
26+
github.com/googleapis/enterprise-certificate-proxy v0.3.4 // indirect
27+
github.com/googleapis/gax-go/v2 v2.13.0 // indirect
28+
go.opencensus.io v0.24.0 // indirect
29+
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.54.0 // indirect
30+
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.54.0 // indirect
31+
go.opentelemetry.io/otel v1.29.0 // indirect
32+
go.opentelemetry.io/otel/metric v1.29.0 // indirect
33+
go.opentelemetry.io/otel/trace v1.29.0 // indirect
34+
golang.org/x/crypto v0.27.0 // indirect
35+
golang.org/x/net v0.29.0 // indirect
36+
golang.org/x/oauth2 v0.23.0 // indirect
37+
golang.org/x/sync v0.8.0 // indirect
38+
golang.org/x/sys v0.25.0 // indirect
39+
golang.org/x/text v0.18.0 // indirect
40+
golang.org/x/time v0.6.0 // indirect
41+
google.golang.org/genproto v0.0.0-20240903143218-8af14fe29dc1 // indirect
42+
google.golang.org/genproto/googleapis/api v0.0.0-20240827150818-7e3bb234dfed // indirect
43+
google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1 // indirect
44+
google.golang.org/grpc v1.67.0 // indirect
45+
)

0 commit comments

Comments
 (0)