Skip to content

Commit a725a5d

Browse files
authored
add vertexai API client (#4)
* add vertexai API client --------- Signed-off-by: Milos Gajdos <[email protected]>
1 parent 6689e87 commit a725a5d

File tree

9 files changed

+437
-0
lines changed

9 files changed

+437
-0
lines changed

cmd/vertexai/main.go

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
package main
2+
3+
import (
4+
"context"
5+
"flag"
6+
"fmt"
7+
"log"
8+
9+
"github.com/milosgajdos/go-embeddings/vertexai"
10+
"golang.org/x/oauth2/google"
11+
)
12+
13+
// Command-line flag values. They are bound to flags in init and
// filled in when main calls flag.Parse.
var (
	// input is the text to embed.
	input string
	// model is the embedding model name.
	model string
	// truncate is sent as the request's AutoTruncate parameter.
	truncate bool
	// taskType is the task type sent with the instance.
	taskType string
	// title is the instance title.
	title string
)
20+
21+
func init() {
22+
flag.StringVar(&input, "input", "what is life", "input data")
23+
flag.StringVar(&model, "model", string(vertexai.EmbedGeckoV2), "model name")
24+
flag.BoolVar(&truncate, "truncate", false, "truncate type")
25+
flag.StringVar(&taskType, "task-type", string(vertexai.RetrQueryTask), "task type")
26+
flag.StringVar(&title, "title", "", "title: only relevant for retrival document tasks")
27+
}
28+
29+
func main() {
30+
flag.Parse()
31+
32+
ctx := context.Background()
33+
34+
ts, err := google.DefaultTokenSource(ctx, vertexai.Scopes)
35+
if err != nil {
36+
log.Fatalf("token source: %v", err)
37+
}
38+
39+
c, err := vertexai.NewClient()
40+
if err != nil {
41+
log.Fatal(err)
42+
}
43+
c.WithTokenSrc(ts)
44+
c.WithModelID(model)
45+
46+
embReq := &vertexai.EmbeddingRequest{
47+
Instances: []vertexai.Instance{
48+
{
49+
Content: input,
50+
TaskType: vertexai.TaskType(taskType),
51+
Title: title,
52+
},
53+
},
54+
Params: vertexai.Params{
55+
AutoTruncate: truncate,
56+
},
57+
}
58+
59+
embs, err := c.Embeddings(context.Background(), embReq)
60+
if err != nil {
61+
log.Fatal(err)
62+
}
63+
64+
fmt.Printf("got %d embeddings", len(embs))
65+
}

cmd/vertexai/vertexai

7.69 MB
Binary file not shown.

go.mod

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,14 @@
11
module github.com/milosgajdos/go-embeddings
22

33
go 1.20
4+
5+
require golang.org/x/oauth2 v0.15.0
6+
7+
require (
8+
cloud.google.com/go/compute v1.20.1 // indirect
9+
cloud.google.com/go/compute/metadata v0.2.3 // indirect
10+
github.com/golang/protobuf v1.5.3 // indirect
11+
golang.org/x/net v0.19.0 // indirect
12+
google.golang.org/appengine v1.6.7 // indirect
13+
google.golang.org/protobuf v1.31.0 // indirect
14+
)

go.sum

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
cloud.google.com/go/compute v1.20.1 h1:6aKEtlUiwEpJzM001l0yFkpXmUVXaN8W+fbkb2AZNbg=
2+
cloud.google.com/go/compute v1.20.1/go.mod h1:4tCnrn48xsqlwSAiLf1HXMQk8CONslYbdiEZc9FEIbM=
3+
cloud.google.com/go/compute/metadata v0.2.3 h1:mg4jlk7mCAj6xXp9UJ4fjI9VUI5rubuGBW5aJ7UnBMY=
4+
cloud.google.com/go/compute/metadata v0.2.3/go.mod h1:VAV5nSsACxMJvgaAuX6Pk2AawlZn8kiOGuCv6gTkwuA=
5+
github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
6+
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
7+
github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg=
8+
github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
9+
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
10+
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
11+
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
12+
golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks=
13+
golang.org/x/net v0.19.0 h1:zTwKpTd2XuCqf8huc7Fo2iSy+4RHPd10s4KzeTnVr1c=
14+
golang.org/x/net v0.19.0/go.mod h1:CfAk/cbD4CthTvqiEl8NpboMuiuOYsAr/7NOjZJtv1U=
15+
golang.org/x/oauth2 v0.15.0 h1:s8pnnxNVzjWyrvYdFUQq5llS1PX2zhPXmccZv99h7uQ=
16+
golang.org/x/oauth2 v0.15.0/go.mod h1:q48ptWNTY5XWf+JNten23lcvHpLJ0ZSxF5ttTHKVCAM=
17+
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
18+
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
19+
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
20+
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
21+
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
22+
google.golang.org/appengine v1.6.7 h1:FZR1q0exgwxzPzp/aF+VccGrSfxfPpkBqjIIEq3ru6c=
23+
google.golang.org/appengine v1.6.7/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc=
24+
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
25+
google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
26+
google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8=
27+
google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=

vertexai/client.go

Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
package vertexai
2+
3+
import (
4+
"bytes"
5+
"context"
6+
"encoding/json"
7+
"fmt"
8+
"io"
9+
"net/http"
10+
"os"
11+
12+
"golang.org/x/oauth2"
13+
)
14+
15+
const (
	// BaseURL is the Vertex AI HTTP API base URL, up to and
	// including the "projects" path element.
	BaseURL = "https://us-central1-aiplatform.googleapis.com/v1/projects"
	// ModelURI is the Vertex AI HTTP API model URI, i.e. the path
	// under which the published Google models are addressed.
	ModelURI = "locations/us-central1/publishers/google/models"
	// EmbedAction is the embedding API action suffix.
	EmbedAction = ":predict"
)
23+
24+
// Client is vertex AI HTTP API client.
25+
type Client struct {
26+
token string
27+
tokenSrc oauth2.TokenSource
28+
projectID string
29+
modelID string
30+
baseURL string
31+
hc *http.Client
32+
}
33+
34+
// NewClient creates a new HTTP client and returns it.
35+
// It reads the Google API token from VERTEXAI_TOKEN env var
36+
// just like the project ID is read from GOOGLE_PROJECT_ID env var
37+
// and uses the default Go http.Client.
38+
// You can override the default options by using the
39+
// client methods.
40+
func NewClient() (*Client, error) {
41+
return &Client{
42+
token: os.Getenv("VERTEXAI_TOKEN"),
43+
modelID: os.Getenv("VERTEXAI_MODEL_ID"),
44+
projectID: os.Getenv("GOOGLE_PROJECT_ID"),
45+
baseURL: BaseURL,
46+
hc: &http.Client{},
47+
}, nil
48+
}
49+
50+
// WithToken sets the API key.
51+
func (c *Client) WithToken(token string) *Client {
52+
c.token = token
53+
return c
54+
}
55+
56+
// WithTokenSrc sets the API token source.
57+
func (c *Client) WithTokenSrc(ts oauth2.TokenSource) *Client {
58+
c.tokenSrc = ts
59+
return c
60+
}
61+
62+
// WithProjectID sets the project ID.
63+
func (c *Client) WithProjectID(id string) *Client {
64+
c.projectID = id
65+
return c
66+
}
67+
68+
// WithModelID sets the model ID.
69+
func (c *Client) WithModelID(id string) *Client {
70+
c.modelID = id
71+
return c
72+
}
73+
74+
// WithBaseURL sets the API base URL.
75+
func (c *Client) WithBaseURL(baseURL string) *Client {
76+
c.baseURL = baseURL
77+
return c
78+
}
79+
80+
// WithHTTPClient sets the HTTP client.
81+
func (c *Client) WithHTTPClient(httpClient *http.Client) *Client {
82+
c.hc = httpClient
83+
return c
84+
}
85+
86+
// ReqOption is an HTTP request functional option.
// It receives the request and may modify it in place.
type ReqOption func(*http.Request)
88+
89+
// WithSetHeader sets the header key to value val.
90+
func WithSetHeader(key, val string) ReqOption {
91+
return func(req *http.Request) {
92+
if req.Header == nil {
93+
req.Header = make(http.Header)
94+
}
95+
req.Header.Set(key, val)
96+
}
97+
}
98+
99+
// WithAddHeader adds the val to key header.
100+
func WithAddHeader(key, val string) ReqOption {
101+
return func(req *http.Request) {
102+
if req.Header == nil {
103+
req.Header = make(http.Header)
104+
}
105+
req.Header.Add(key, val)
106+
}
107+
}
108+
109+
func (c *Client) newRequest(ctx context.Context, method, url string, body io.Reader, opts ...ReqOption) (*http.Request, error) {
110+
if ctx == nil {
111+
ctx = context.Background()
112+
}
113+
if body == nil {
114+
body = &bytes.Reader{}
115+
}
116+
117+
req, err := http.NewRequestWithContext(ctx, method, url, body)
118+
if err != nil {
119+
return nil, err
120+
}
121+
122+
for _, setOption := range opts {
123+
setOption(req)
124+
}
125+
126+
if c.token == "" {
127+
var err error
128+
c.token, err = GetToken(c.tokenSrc)
129+
if err != nil {
130+
return nil, err
131+
}
132+
}
133+
134+
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.token))
135+
req.Header.Set("Accept", "application/json; charset=utf-8")
136+
if body != nil {
137+
// if no content-type is specified we default to json
138+
if ct := req.Header.Get("Content-Type"); len(ct) == 0 {
139+
req.Header.Set("Content-Type", "application/json; charset=utf-8")
140+
}
141+
}
142+
143+
return req, nil
144+
}
145+
146+
func (c *Client) doRequest(req *http.Request) (*http.Response, error) {
147+
resp, err := c.hc.Do(req)
148+
if err != nil {
149+
return nil, err
150+
}
151+
if resp.StatusCode >= http.StatusOK && resp.StatusCode < http.StatusBadRequest {
152+
return resp, nil
153+
}
154+
defer resp.Body.Close()
155+
156+
var apiErr APIError
157+
if err := json.NewDecoder(resp.Body).Decode(&apiErr); err != nil {
158+
return nil, err
159+
}
160+
161+
return nil, apiErr
162+
}

vertexai/embedding.go

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
package vertexai
2+
3+
import (
4+
"bytes"
5+
"context"
6+
"encoding/json"
7+
"io"
8+
"net/http"
9+
"net/url"
10+
)
11+
12+
// Embedding is a Vertex AI API vector embedding.
type Embedding struct {
	// Vector holds the embedding values.
	Vector []float64 `json:"vector"`
}
16+
17+
// EmbeddingRequest sent to API endpoint.
18+
// https://cloud.google.com/vertex-ai/docs/generative-ai/embeddings/get-text-embeddings#generative-ai-get-text-embedding-drest
19+
type EmbeddingRequest struct {
20+
Instances []Instance `json:"instances"`
21+
Params Params `json:"parameters"`
22+
}
23+
24+
// NOTE: Title is only valid with TaskType set to RetrDocTask
25+
// https://cloud.google.com/vertex-ai/docs/generative-ai/embeddings/get-text-embeddings#api_changes_to_models_released_on_or_after_august_2023
26+
type Instance struct {
27+
TaskType TaskType `json:"task_type"`
28+
Title string `json:"title,omitempty"`
29+
Content string `json:"content"`
30+
}
31+
32+
// Params are additional API request parameters passed via the body.
type Params struct {
	// AutoTruncate: if set to false, text that exceeds the token
	// limit (3,072) causes the request to fail. The API documents
	// its default as true, but this field has no omitempty tag, so
	// the Go zero value (false) is always sent explicitly.
	AutoTruncate bool `json:"autoTruncate"`
}
38+
39+
// EmbedddingResponse received from API endpoint.
40+
// https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/text-embeddings#response_body
41+
type EmbedddingResponse struct {
42+
Predictions []Predictions `json:"predictions"`
43+
Metadata map[string]any `json:"metadata"`
44+
}
45+
46+
// Predictions is the generated response
47+
type Predictions struct {
48+
Embeddings struct {
49+
Values []float64 `json:"values"`
50+
Statistics Statistics `json:"statistics"`
51+
} `json:"embeddings"`
52+
}
53+
54+
// Statistics defines the statistics for a text embedding.
type Statistics struct {
	// TokenCount is decoded from the "token_count" field.
	TokenCount int `json:"token_count"`
	// Truncated is decoded from the "truncated" field.
	Truncated bool `json:"truncated"`
}
59+
60+
func ToEmbeddings(r io.Reader) ([]*Embedding, error) {
61+
var resp EmbedddingResponse
62+
if err := json.NewDecoder(r).Decode(&resp); err != nil {
63+
return nil, err
64+
}
65+
embs := make([]*Embedding, 0, len(resp.Predictions))
66+
for _, p := range resp.Predictions {
67+
floats := make([]float64, len(p.Embeddings.Values))
68+
copy(floats, p.Embeddings.Values)
69+
emb := &Embedding{
70+
Vector: floats,
71+
}
72+
embs = append(embs, emb)
73+
}
74+
return embs, nil
75+
}
76+
77+
// Embeddings returns embeddings for every object in EmbeddingRequest.
78+
func (c *Client) Embeddings(ctx context.Context, embReq *EmbeddingRequest) ([]*Embedding, error) {
79+
u, err := url.Parse(c.baseURL + "/" + c.projectID + "/" + ModelURI + "/" + c.modelID + EmbedAction)
80+
if err != nil {
81+
return nil, err
82+
}
83+
84+
var body = &bytes.Buffer{}
85+
enc := json.NewEncoder(body)
86+
enc.SetEscapeHTML(false)
87+
if err := enc.Encode(embReq); err != nil {
88+
return nil, err
89+
}
90+
91+
req, err := c.newRequest(ctx, http.MethodPost, u.String(), body)
92+
if err != nil {
93+
return nil, err
94+
}
95+
96+
resp, err := c.doRequest(req)
97+
if err != nil {
98+
return nil, err
99+
}
100+
defer resp.Body.Close()
101+
102+
return ToEmbeddings(resp.Body)
103+
}

vertexai/error.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
package vertexai
2+
3+
import "encoding/json"
4+
5+
// APIError is the error payload of an API response.
type APIError struct {
	// RespError is decoded from the "error" object of the payload.
	RespError struct {
		Code    int    `json:"code"`
		Message string `json:"message"`
		Status  string `json:"status"`
	} `json:"error"`
}

// Error implements the error interface by rendering the error as
// JSON. It returns "unknown error" if the value cannot be marshaled.
func (e APIError) Error() string {
	if data, err := json.Marshal(e); err == nil {
		return string(data)
	}
	return "unknown error"
}

0 commit comments

Comments
 (0)