Skip to content

Commit 329b912

Browse files
authored
update: return API embedding response (#6)
* update: return API embedding response Originally Embeddings method would return an object which was a result of a conversion to a smaller type which only contained the embeddings. This prevents the clients from accessing the API response fields which is not ideal. This commit reverts to returning deserialized API responses and instead provides an exported func in each package that lets the consumers to convert the object into a higher level embedding type defined in the module root. Signed-off-by: Milos Gajdos <[email protected]> * update: make sure examples dont have empty string inputs Signed-off-by: Milos Gajdos <[email protected]> --------- Signed-off-by: Milos Gajdos <[email protected]>
1 parent 50ba631 commit 329b912

File tree

7 files changed

+136
-85
lines changed

7 files changed

+136
-85
lines changed

cmd/cohere/main.go

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ var (
1717
)
1818

1919
func init() {
20-
flag.StringVar(&input, "input", "", "input data")
20+
flag.StringVar(&input, "input", "what is life", "input data")
2121
flag.StringVar(&model, "model", string(cohere.EnglishV3), "model name")
2222
flag.StringVar(&truncate, "truncate", string(cohere.NoneTrunc), "truncate type")
2323
flag.StringVar(&inputType, "input-type", string(cohere.ClusteringInput), "input type")
@@ -35,7 +35,12 @@ func main() {
3535
Truncate: cohere.Truncate(truncate),
3636
}
3737

38-
embs, err := c.Embeddings(context.Background(), embReq)
38+
embResp, err := c.Embeddings(context.Background(), embReq)
39+
if err != nil {
40+
log.Fatal(err)
41+
}
42+
43+
embs, err := cohere.ToEmbeddings(embResp)
3944
if err != nil {
4045
log.Fatal(err)
4146
}

cmd/openai/main.go

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ var (
1616
)
1717

1818
func init() {
19-
flag.StringVar(&input, "input", "", "input data")
19+
flag.StringVar(&input, "input", "what is life", "input data")
2020
flag.StringVar(&model, "model", string(openai.TextAdaV2), "model name")
2121
flag.StringVar(&encoding, "encoding", string(openai.EncodingFloat), "encoding format")
2222
}
@@ -32,7 +32,12 @@ func main() {
3232
EncodingFormat: openai.EncodingFormat(encoding),
3333
}
3434

35-
embs, err := c.Embeddings(context.Background(), embReq)
35+
embResp, err := c.Embeddings(context.Background(), embReq)
36+
if err != nil {
37+
log.Fatal(err)
38+
}
39+
40+
embs, err := openai.ToEmbeddings(embResp)
3641
if err != nil {
3742
log.Fatal(err)
3843
}

cmd/vertexai/main.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,12 @@ func main() {
5353
},
5454
}
5555

56-
embs, err := c.Embeddings(context.Background(), embReq)
56+
embResp, err := c.Embeddings(context.Background(), embReq)
57+
if err != nil {
58+
log.Fatal(err)
59+
}
60+
61+
embs, err := vertexai.ToEmbeddings(embResp)
5762
if err != nil {
5863
log.Fatal(err)
5964
}

cohere/embedding.go

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,11 @@ import (
44
"bytes"
55
"context"
66
"encoding/json"
7-
"io"
87
"net/http"
98
"net/url"
10-
)
119

12-
// Embedding is vector embedding.
13-
type Embedding struct {
14-
Vector []float64 `json:"vector"`
15-
}
10+
"github.com/milosgajdos/go-embeddings"
11+
)
1612

1713
// EmbeddingRequest sent to API endpoint.
1814
type EmbeddingRequest struct {
@@ -38,18 +34,14 @@ type APIVersion struct {
3834
Version string `json:"version"`
3935
}
4036

41-
// toEmbeddings decodes the raw API response,
37+
// ToEmbeddings converts the raw API response,
4238
// parses it into a slice of embeddings and returns it.
43-
func toEmbeddings(r io.Reader) ([]*Embedding, error) {
44-
var resp EmbedddingResponse
45-
if err := json.NewDecoder(r).Decode(&resp); err != nil {
46-
return nil, err
47-
}
48-
embs := make([]*Embedding, 0, len(resp.Embeddings))
49-
for _, e := range resp.Embeddings {
39+
func ToEmbeddings(e *EmbedddingResponse) ([]*embeddings.Embedding, error) {
40+
embs := make([]*embeddings.Embedding, 0, len(e.Embeddings))
41+
for _, e := range e.Embeddings {
5042
floats := make([]float64, len(e))
5143
copy(floats, e)
52-
emb := &Embedding{
44+
emb := &embeddings.Embedding{
5345
Vector: floats,
5446
}
5547
embs = append(embs, emb)
@@ -58,7 +50,7 @@ func toEmbeddings(r io.Reader) ([]*Embedding, error) {
5850
}
5951

6052
// Embeddings returns embeddings for every object in EmbeddingRequest.
61-
func (c *Client) Embeddings(ctx context.Context, embReq *EmbeddingRequest) ([]*Embedding, error) {
53+
func (c *Client) Embeddings(ctx context.Context, embReq *EmbeddingRequest) (*EmbedddingResponse, error) {
6254
u, err := url.Parse(c.baseURL + "/" + c.version + "/embed")
6355
if err != nil {
6456
return nil, err
@@ -82,5 +74,10 @@ func (c *Client) Embeddings(ctx context.Context, embReq *EmbeddingRequest) ([]*E
8274
}
8375
defer resp.Body.Close()
8476

85-
return toEmbeddings(resp.Body)
77+
e := new(EmbedddingResponse)
78+
if err := json.NewDecoder(resp.Body).Decode(e); err != nil {
79+
return nil, err
80+
}
81+
82+
return e, nil
8683
}

embedding.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
package embeddings
2+
3+
// Embedding is vector embedding.
4+
type Embedding struct {
5+
Vector []float64 `json:"vector"`
6+
}
7+
8+
// ToFloat32 returns Embedding verctor as a slice of float32.
9+
func (e Embedding) ToFloat32() []float32 {
10+
floats := make([]float32, 0, len(e.Vector))
11+
for _, f := range e.Vector {
12+
floats = append(floats, float32(f))
13+
}
14+
return floats
15+
}

openai/embedding.go

Lines changed: 73 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -11,20 +11,9 @@ import (
1111
"math"
1212
"net/http"
1313
"net/url"
14-
)
15-
16-
// Usage tracks API token usage.
17-
type Usage struct {
18-
PromptTokens int `json:"prompt_tokens"`
19-
TotalTokens int `json:"total_tokens"`
20-
}
2114

22-
// Embedding is openai API embedding.
23-
type Embedding struct {
24-
Object string `json:"object"`
25-
Index int `json:"index"`
26-
Vector []float64 `json:"vector"`
27-
}
15+
"github.com/milosgajdos/go-embeddings"
16+
)
2817

2918
// EmbeddingString is base64 encoded embedding.
3019
type EmbeddingString string
@@ -50,6 +39,27 @@ func (s EmbeddingString) Decode() ([]float64, error) {
5039
return floats, nil
5140
}
5241

42+
// Usage tracks API token usage.
43+
type Usage struct {
44+
PromptTokens int `json:"prompt_tokens"`
45+
TotalTokens int `json:"total_tokens"`
46+
}
47+
48+
// Data stores vector embeddings.
49+
type Data struct {
50+
Object string `json:"object"`
51+
Index int `json:"index"`
52+
Embedding []float64 `json:"embedding"`
53+
}
54+
55+
// EmbeddingResponseGen is the API response.
56+
type EmbeddingResponse struct {
57+
Object string `json:"object"`
58+
Data []Data `json:"data"`
59+
Model Model `json:"model"`
60+
Usage Usage `json:"usage"`
61+
}
62+
5363
// EmbeddingRequest is serialized and sent to the API server.
5464
type EmbeddingRequest struct {
5565
Input any `json:"input"`
@@ -58,64 +68,82 @@ type EmbeddingRequest struct {
5868
EncodingFormat EncodingFormat `json:"encoding_format,omitempty"`
5969
}
6070

61-
// Data stores the raw embeddings.
62-
// It's used when deserializing data from API.
63-
type Data[T any] struct {
71+
// DataGen is a generic struct used for deserializing vector embeddings.
72+
type DataGen[T any] struct {
6473
Object string `json:"object"`
6574
Index int `json:"index"`
6675
Embedding T `json:"embedding"`
6776
}
6877

69-
// EmbeddingResponse is the API response.
70-
type EmbeddingResponse[T any] struct {
71-
Object string `json:"object"`
72-
Data []Data[T] `json:"data"`
73-
Model Model `json:"model"`
74-
Usage Usage `json:"usage"`
78+
// EmbeddingResponseGen is a generic struct used for deserializing API response.
79+
type EmbeddingResponseGen[T any] struct {
80+
Object string `json:"object"`
81+
Data []DataGen[T] `json:"data"`
82+
Model Model `json:"model"`
83+
Usage Usage `json:"usage"`
7584
}
7685

77-
// toEmbeddings decodes the raw API response,
86+
// ToEmbeddings converts the raw API response,
7887
// parses it into a slice of embeddings and returns it.
79-
func toEmbeddings[T any](resp io.Reader) ([]*Embedding, error) {
88+
func ToEmbeddings(e *EmbeddingResponse) ([]*embeddings.Embedding, error) {
89+
embs := make([]*embeddings.Embedding, 0, len(e.Data))
90+
for _, d := range e.Data {
91+
floats := make([]float64, len(d.Embedding))
92+
copy(floats, d.Embedding)
93+
emb := &embeddings.Embedding{
94+
Vector: floats,
95+
}
96+
embs = append(embs, emb)
97+
}
98+
return embs, nil
99+
}
100+
101+
// toEmbeddingResp decodes the raw API response,
102+
// parses it into a slice of embeddings and returns it.
103+
func toEmbeddingResp[T any](resp io.Reader) (*EmbeddingResponse, error) {
80104
data := new(T)
81105
if err := json.NewDecoder(resp).Decode(data); err != nil {
82106
return nil, err
83107
}
84108

85109
switch e := any(data).(type) {
86-
case *EmbeddingResponse[EmbeddingString]:
87-
embs := make([]*Embedding, 0, len(e.Data))
110+
case *EmbeddingResponseGen[EmbeddingString]:
111+
embData := make([]Data, 0, len(e.Data))
88112
for _, d := range e.Data {
89113
floats, err := d.Embedding.Decode()
90114
if err != nil {
91115
return nil, err
92116
}
93-
emb := &Embedding{
94-
Object: d.Object,
95-
Index: d.Index,
96-
Vector: floats,
97-
}
98-
embs = append(embs, emb)
117+
embData = append(embData, Data{
118+
Object: d.Object,
119+
Index: d.Index,
120+
Embedding: floats,
121+
})
99122
}
100-
return embs, nil
101-
case *EmbeddingResponse[[]float64]:
102-
embs := make([]*Embedding, 0, len(e.Data))
123+
return &EmbeddingResponse{
124+
Object: e.Object,
125+
Data: embData,
126+
Model: e.Model,
127+
Usage: e.Usage,
128+
}, nil
129+
case *EmbeddingResponseGen[[]float64]:
130+
embData := make([]Data, 0, len(e.Data))
103131
for _, d := range e.Data {
104-
emb := &Embedding{
105-
Object: d.Object,
106-
Index: d.Index,
107-
Vector: d.Embedding,
108-
}
109-
embs = append(embs, emb)
132+
embData = append(embData, Data(d))
110133
}
111-
return embs, nil
134+
return &EmbeddingResponse{
135+
Object: e.Object,
136+
Data: embData,
137+
Model: e.Model,
138+
Usage: e.Usage,
139+
}, nil
112140
}
113141

114142
return nil, ErrInValidData
115143
}
116144

117145
// Embeddings returns embeddings for every object in EmbeddingRequest.
118-
func (c *Client) Embeddings(ctx context.Context, embReq *EmbeddingRequest) ([]*Embedding, error) {
146+
func (c *Client) Embeddings(ctx context.Context, embReq *EmbeddingRequest) (*EmbeddingResponse, error) {
119147
u, err := url.Parse(c.baseURL + "/" + c.version + "/embeddings")
120148
if err != nil {
121149
return nil, err
@@ -141,9 +169,9 @@ func (c *Client) Embeddings(ctx context.Context, embReq *EmbeddingRequest) ([]*E
141169

142170
switch embReq.EncodingFormat {
143171
case EncodingBase64:
144-
return toEmbeddings[EmbeddingResponse[EmbeddingString]](resp.Body)
172+
return toEmbeddingResp[EmbeddingResponseGen[EmbeddingString]](resp.Body)
145173
case EncodingFloat:
146-
return toEmbeddings[EmbeddingResponse[[]float64]](resp.Body)
174+
return toEmbeddingResp[EmbeddingResponseGen[[]float64]](resp.Body)
147175
}
148176

149177
return nil, ErrUnsupportedEncoding

vertexai/embedding.go

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,11 @@ import (
44
"bytes"
55
"context"
66
"encoding/json"
7-
"io"
87
"net/http"
98
"net/url"
10-
)
119

12-
// Embedding is cohere API vector embedding.
13-
type Embedding struct {
14-
Vector []float64 `json:"vector"`
15-
}
10+
"github.com/milosgajdos/go-embeddings"
11+
)
1612

1713
// EmbeddingRequest sent to API endpoint.
1814
// https://cloud.google.com/vertex-ai/docs/generative-ai/embeddings/get-text-embeddings#generative-ai-get-text-embedding-drest
@@ -57,18 +53,13 @@ type Statistics struct {
5753
Truncated bool `json:"truncated"`
5854
}
5955

60-
// toEmbeddings decodes the raw API response,
61-
// parses it into a slice of embeddings and returns it.
62-
func toEmbeddings(r io.Reader) ([]*Embedding, error) {
63-
var resp EmbedddingResponse
64-
if err := json.NewDecoder(r).Decode(&resp); err != nil {
65-
return nil, err
66-
}
67-
embs := make([]*Embedding, 0, len(resp.Predictions))
68-
for _, p := range resp.Predictions {
56+
// ToEmbeddings converts the API response to embeddings object.
57+
func ToEmbeddings(e *EmbedddingResponse) ([]*embeddings.Embedding, error) {
58+
embs := make([]*embeddings.Embedding, 0, len(e.Predictions))
59+
for _, p := range e.Predictions {
6960
floats := make([]float64, len(p.Embeddings.Values))
7061
copy(floats, p.Embeddings.Values)
71-
emb := &Embedding{
62+
emb := &embeddings.Embedding{
7263
Vector: floats,
7364
}
7465
embs = append(embs, emb)
@@ -77,7 +68,7 @@ func toEmbeddings(r io.Reader) ([]*Embedding, error) {
7768
}
7869

7970
// Embeddings returns embeddings for every object in EmbeddingRequest.
80-
func (c *Client) Embeddings(ctx context.Context, embReq *EmbeddingRequest) ([]*Embedding, error) {
71+
func (c *Client) Embeddings(ctx context.Context, embReq *EmbeddingRequest) (*EmbedddingResponse, error) {
8172
u, err := url.Parse(c.baseURL + "/" + c.projectID + "/" + ModelURI + "/" + c.modelID + EmbedAction)
8273
if err != nil {
8374
return nil, err
@@ -101,5 +92,10 @@ func (c *Client) Embeddings(ctx context.Context, embReq *EmbeddingRequest) ([]*E
10192
}
10293
defer resp.Body.Close()
10394

104-
return toEmbeddings(resp.Body)
95+
e := new(EmbedddingResponse)
96+
if err := json.NewDecoder(resp.Body).Decode(e); err != nil {
97+
return nil, err
98+
}
99+
100+
return e, nil
105101
}

0 commit comments

Comments
 (0)