@@ -11,20 +11,9 @@ import (
1111 "math"
1212 "net/http"
1313 "net/url"
14- )
15-
16- // Usage tracks API token usage.
17- type Usage struct {
18- PromptTokens int `json:"prompt_tokens"`
19- TotalTokens int `json:"total_tokens"`
20- }
2114
22- // Embedding is openai API embedding.
23- type Embedding struct {
24- Object string `json:"object"`
25- Index int `json:"index"`
26- Vector []float64 `json:"vector"`
27- }
15+ "github.com/milosgajdos/go-embeddings"
16+ )
2817
2918// EmbeddingString is base64 encoded embedding.
3019type EmbeddingString string
@@ -50,6 +39,27 @@ func (s EmbeddingString) Decode() ([]float64, error) {
5039 return floats , nil
5140}
5241
42+ // Usage tracks API token usage.
43+ type Usage struct {
44+ PromptTokens int `json:"prompt_tokens"`
45+ TotalTokens int `json:"total_tokens"`
46+ }
47+
48+ // Data stores vector embeddings.
49+ type Data struct {
50+ Object string `json:"object"`
51+ Index int `json:"index"`
52+ Embedding []float64 `json:"embedding"`
53+ }
54+
55+ // EmbeddingResponse is the API response.
56+ type EmbeddingResponse struct {
57+ Object string `json:"object"`
58+ Data []Data `json:"data"`
59+ Model Model `json:"model"`
60+ Usage Usage `json:"usage"`
61+ }
62+
5363// EmbeddingRequest is serialized and sent to the API server.
5464type EmbeddingRequest struct {
5565 Input any `json:"input"`
@@ -58,64 +68,82 @@ type EmbeddingRequest struct {
5868 EncodingFormat EncodingFormat `json:"encoding_format,omitempty"`
5969}
6070
61- // Data stores the raw embeddings.
62- // It's used when deserializing data from API.
63- type Data [T any ] struct {
71+ // DataGen is a generic struct used for deserializing vector embeddings.
72+ type DataGen [T any ] struct {
6473 Object string `json:"object"`
6574 Index int `json:"index"`
6675 Embedding T `json:"embedding"`
6776}
6877
69- // EmbeddingResponse is the API response.
70- type EmbeddingResponse [T any ] struct {
71- Object string `json:"object"`
72- Data []Data [T ] `json:"data"`
73- Model Model `json:"model"`
74- Usage Usage `json:"usage"`
78+ // EmbeddingResponseGen is a generic struct used for deserializing the API response.
79+ type EmbeddingResponseGen [T any ] struct {
80+ Object string `json:"object"`
81+ Data []DataGen [T ] `json:"data"`
82+ Model Model `json:"model"`
83+ Usage Usage `json:"usage"`
7584}
7685
77- // toEmbeddings decodes the raw API response,
86+ // ToEmbeddings converts the decoded API response,
7887// parses it into a slice of embeddings and returns it.
79- func toEmbeddings [T any ](resp io.Reader ) ([]* Embedding , error ) {
88+ func ToEmbeddings (e * EmbeddingResponse ) ([]* embeddings.Embedding , error ) {
89+ embs := make ([]* embeddings.Embedding , 0 , len (e .Data ))
90+ for _ , d := range e .Data {
91+ floats := make ([]float64 , len (d .Embedding ))
92+ copy (floats , d .Embedding )
93+ emb := & embeddings.Embedding {
94+ Vector : floats ,
95+ }
96+ embs = append (embs , emb )
97+ }
98+ return embs , nil
99+ }
100+
101+ // toEmbeddingResp decodes the raw API response into T,
102+ // converts it into an EmbeddingResponse and returns it.
103+ func toEmbeddingResp [T any ](resp io.Reader ) (* EmbeddingResponse , error ) {
80104 data := new (T )
81105 if err := json .NewDecoder (resp ).Decode (data ); err != nil {
82106 return nil , err
83107 }
84108
85109 switch e := any (data ).(type ) {
86- case * EmbeddingResponse [EmbeddingString ]:
87- embs := make ([]* Embedding , 0 , len (e .Data ))
110+ case * EmbeddingResponseGen [EmbeddingString ]:
111+ embData := make ([]Data , 0 , len (e .Data ))
88112 for _ , d := range e .Data {
89113 floats , err := d .Embedding .Decode ()
90114 if err != nil {
91115 return nil , err
92116 }
93- emb := & Embedding {
94- Object : d .Object ,
95- Index : d .Index ,
96- Vector : floats ,
97- }
98- embs = append (embs , emb )
117+ embData = append (embData , Data {
118+ Object : d .Object ,
119+ Index : d .Index ,
120+ Embedding : floats ,
121+ })
99122 }
100- return embs , nil
101- case * EmbeddingResponse [[]float64 ]:
102- embs := make ([]* Embedding , 0 , len (e .Data ))
123+ return & EmbeddingResponse {
124+ Object : e .Object ,
125+ Data : embData ,
126+ Model : e .Model ,
127+ Usage : e .Usage ,
128+ }, nil
129+ case * EmbeddingResponseGen [[]float64 ]:
130+ embData := make ([]Data , 0 , len (e .Data ))
103131 for _ , d := range e .Data {
104- emb := & Embedding {
105- Object : d .Object ,
106- Index : d .Index ,
107- Vector : d .Embedding ,
108- }
109- embs = append (embs , emb )
132+ embData = append (embData , Data (d ))
110133 }
111- return embs , nil
134+ return & EmbeddingResponse {
135+ Object : e .Object ,
136+ Data : embData ,
137+ Model : e .Model ,
138+ Usage : e .Usage ,
139+ }, nil
112140 }
113141
114142 return nil , ErrInValidData
115143}
116144
117145// Embeddings returns embeddings for every object in EmbeddingRequest.
118- func (c * Client ) Embeddings (ctx context.Context , embReq * EmbeddingRequest ) ([] * Embedding , error ) {
146+ func (c * Client ) Embeddings (ctx context.Context , embReq * EmbeddingRequest ) (* EmbeddingResponse , error ) {
119147 u , err := url .Parse (c .baseURL + "/" + c .version + "/embeddings" )
120148 if err != nil {
121149 return nil , err
@@ -141,9 +169,9 @@ func (c *Client) Embeddings(ctx context.Context, embReq *EmbeddingRequest) ([]*E
141169
142170 switch embReq .EncodingFormat {
143171 case EncodingBase64 :
144- return toEmbeddings [ EmbeddingResponse [EmbeddingString ]](resp .Body )
172+ return toEmbeddingResp [ EmbeddingResponseGen [EmbeddingString ]](resp .Body )
145173 case EncodingFloat :
146- return toEmbeddings [ EmbeddingResponse [[]float64 ]](resp .Body )
174+ return toEmbeddingResp [ EmbeddingResponseGen [[]float64 ]](resp .Body )
147175 }
148176
149177 return nil , ErrUnsupportedEncoding
0 commit comments