Skip to content

Commit 8448691

Browse files
authored
Add openai embeddings (#1)
Add OpenAI API client Signed-off-by: Milos Gajdos <[email protected]>
1 parent a0156bd commit 8448691

File tree

5 files changed

+371
-0
lines changed

5 files changed

+371
-0
lines changed

cmd/openai/main.go

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
package main
2+
3+
import (
4+
"context"
5+
"flag"
6+
"fmt"
7+
"log"
8+
9+
"github.com/milosgajdos/go-embeddings/openai"
10+
)
11+
12+
// Values of the command-line flags registered in init.
var (
	// input holds the value of the -input flag.
	input string
	// model holds the value of the -model flag.
	model string
	// encoding holds the value of the -encoding flag.
	encoding string
)
17+
18+
func init() {
19+
flag.StringVar(&input, "input", "", "input data")
20+
flag.StringVar(&model, "model", string(openai.TextAdaV2), "model name")
21+
flag.StringVar(&encoding, "encoding", string(openai.EncodingFloat), "encoding format")
22+
}
23+
24+
func main() {
25+
flag.Parse()
26+
27+
c, err := openai.NewClient()
28+
if err != nil {
29+
log.Fatal(err)
30+
}
31+
32+
embReq := &openai.EmbeddingRequest{
33+
Input: input,
34+
Model: openai.Model(model),
35+
EncodingFormat: openai.EncodingFormat(encoding),
36+
}
37+
38+
embs, err := c.Embeddings(context.Background(), embReq)
39+
if err != nil {
40+
log.Fatal(err)
41+
}
42+
43+
fmt.Printf("got %d embeddings", len(embs))
44+
}

openai/client.go

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
package openai
2+
3+
import (
4+
"bytes"
5+
"context"
6+
"encoding/json"
7+
"fmt"
8+
"io"
9+
"net/http"
10+
"os"
11+
)
12+
13+
const (
	// BaseURL is the base URL of the OpenAI HTTP API.
	BaseURL = "https://api.openai.com/v1"
	// OrgHeader is the name of the HTTP request header that carries
	// the organization ID.
	OrgHeader = "OpenAI-Organization"
)
19+
20+
// Client is a client for the OpenAI HTTP API.
type Client struct {
	// apiKey is sent as the bearer token in the Authorization header.
	apiKey string
	// baseURL is the URL prefix that endpoint paths are appended to.
	baseURL string
	// orgID, when non-empty, is sent in the OpenAI-Organization header.
	orgID string
	// hc is the HTTP client used to execute requests.
	hc *http.Client
}
27+
28+
// NewClient creates a new HTTP client and returns it.
29+
// It reads the OpenAI API key from OPENAI_API_KEY env var
30+
// and uses the default Go http.Client.
31+
// You can override the default options by using the
32+
// client methods.
33+
func NewClient() (*Client, error) {
34+
return &Client{
35+
apiKey: os.Getenv("OPENAI_API_KEY"),
36+
baseURL: BaseURL,
37+
orgID: "",
38+
hc: &http.Client{},
39+
}, nil
40+
}
41+
42+
// WithAPIKey sets the API key.
43+
func (c *Client) WithAPIKey(apiKey string) *Client {
44+
c.apiKey = apiKey
45+
return c
46+
}
47+
48+
// WithBaseURL sets the API base URL.
49+
func (c *Client) WithBaseURL(baseURL string) *Client {
50+
c.baseURL = baseURL
51+
return c
52+
}
53+
54+
// WithOrgID sets the organization ID.
55+
func (c *Client) WithOrgID(orgID string) *Client {
56+
c.orgID = orgID
57+
return c
58+
}
59+
60+
// WithHTTPClient sets the HTTP client.
61+
func (c *Client) WithHTTPClient(httpClient *http.Client) *Client {
62+
c.hc = httpClient
63+
return c
64+
}
65+
66+
// ReqOption is a functional option that modifies an HTTP request.
type ReqOption func(*http.Request)

// WithSetHeader returns a ReqOption that sets the header key to val,
// replacing any values already stored under key.
func WithSetHeader(key, val string) ReqOption {
	return func(req *http.Request) {
		requestHeader(req).Set(key, val)
	}
}

// WithAddHeader returns a ReqOption that appends val to the values
// already stored under the header key.
func WithAddHeader(key, val string) ReqOption {
	return func(req *http.Request) {
		requestHeader(req).Add(key, val)
	}
}

// requestHeader returns req.Header, allocating it first when it is nil.
func requestHeader(req *http.Request) http.Header {
	if req.Header == nil {
		req.Header = make(http.Header)
	}
	return req.Header
}
88+
89+
func (c *Client) newRequest(ctx context.Context, method, uri string, body io.Reader, opts ...ReqOption) (*http.Request, error) {
90+
if ctx == nil {
91+
ctx = context.Background()
92+
}
93+
if body == nil {
94+
body = &bytes.Reader{}
95+
}
96+
97+
req, err := http.NewRequestWithContext(ctx, method, uri, body)
98+
if err != nil {
99+
return nil, err
100+
}
101+
102+
for _, setOption := range opts {
103+
setOption(req)
104+
}
105+
106+
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.apiKey))
107+
if len(c.orgID) != 0 {
108+
req.Header.Set("OpenAI-Organization", c.orgID)
109+
}
110+
111+
req.Header.Set("Accept", "application/json; charset=utf-8")
112+
if body != nil {
113+
// if no content-type is specified we default to json
114+
if ct := req.Header.Get("Content-Type"); len(ct) == 0 {
115+
req.Header.Set("Content-Type", "application/json; charset=utf-8")
116+
}
117+
}
118+
119+
return req, nil
120+
}
121+
122+
func (c *Client) doRequest(req *http.Request) (*http.Response, error) {
123+
resp, err := c.hc.Do(req)
124+
if err != nil {
125+
return nil, err
126+
}
127+
if resp.StatusCode >= http.StatusOK && resp.StatusCode < http.StatusBadRequest {
128+
return resp, nil
129+
}
130+
defer resp.Body.Close()
131+
132+
var apiErr APIError
133+
if err := json.NewDecoder(resp.Body).Decode(&apiErr); err != nil {
134+
return nil, err
135+
}
136+
137+
return nil, apiErr
138+
}

openai/embedding.go

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
package openai
2+
3+
import (
4+
"bytes"
5+
"context"
6+
"encoding/base64"
7+
"encoding/binary"
8+
"encoding/json"
9+
"fmt"
10+
"io"
11+
"math"
12+
"net/http"
13+
"net/url"
14+
)
15+
16+
// Usage reports the token usage of an API request.
type Usage struct {
	// PromptTokens is read from the "prompt_tokens" JSON field.
	PromptTokens int `json:"prompt_tokens"`
	// TotalTokens is read from the "total_tokens" JSON field.
	TotalTokens int `json:"total_tokens"`
}
21+
22+
// Embedding is a vector embedding returned by the OpenAI API.
type Embedding struct {
	// Object is the value of the "object" JSON field.
	Object string `json:"object"`
	// Index is the value of the "index" JSON field.
	Index int `json:"index"`
	// Embedding holds the embedding vector itself.
	Embedding []float64 `json:"embedding"`
}
28+
29+
// EmbeddingString is a base64 encoded embedding.
type EmbeddingString string

// Decode decodes the base64 string into a slice of floats.
//
// The decoded payload is read as a sequence of little-endian IEEE 754
// 32-bit floats, which is the layout the OpenAI API uses for base64
// encoded embeddings; every value is widened to float64. Reading the
// payload as 64-bit floats would silently yield half as many garbage
// values.
//
// An error is returned when the string is not valid standard base64 or
// when the decoded length is not a multiple of four bytes. An empty
// string decodes to an empty, non-nil slice.
func (s EmbeddingString) Decode() ([]float64, error) {
	decoded, err := base64.StdEncoding.DecodeString(string(s))
	if err != nil {
		return nil, err
	}

	// size is the width in bytes of one encoded float32 value.
	const size = 4
	if len(decoded)%size != 0 {
		return nil, fmt.Errorf("invalid base64 encoded string length")
	}

	floats := make([]float64, len(decoded)/size)
	for i := range floats {
		bits := binary.LittleEndian.Uint32(decoded[i*size:])
		floats[i] = float64(math.Float32frombits(bits))
	}

	return floats, nil
}
51+
52+
// EmbeddingRequest is serialized and sent to the API server.
53+
type EmbeddingRequest struct {
54+
Input any `json:"input"`
55+
Model Model `json:"model"`
56+
User string `json:"user"`
57+
EncodingFormat EncodingFormat `json:"encoding_format,omitempty"`
58+
}
59+
60+
// Data is a single item of the "data" array in an embeddings response.
// T is the Go type that the item's "embedding" field is decoded into.
type Data[T any] struct {
	// Object is the value of the "object" JSON field.
	Object string `json:"object"`
	// Index is the value of the "index" JSON field.
	Index int `json:"index"`
	// Embedding is the still-encoded or already-decoded embedding.
	Embedding T `json:"embedding"`
}
66+
67+
// EmbeddingResponse is the API response from a Create embeddings request.
68+
type EmbeddingResponse[T any] struct {
69+
Object string `json:"object"`
70+
Data []Data[T] `json:"data"`
71+
Model Model `json:"model"`
72+
Usage Usage `json:"usage"`
73+
}
74+
75+
func ToEmbeddings[T any](resp io.Reader) ([]*Embedding, error) {
76+
data := new(T)
77+
if err := json.NewDecoder(resp).Decode(data); err != nil {
78+
return nil, err
79+
}
80+
81+
switch e := any(data).(type) {
82+
case *EmbeddingResponse[EmbeddingString]:
83+
embs := make([]*Embedding, 0, len(e.Data))
84+
for _, d := range e.Data {
85+
floats, err := d.Embedding.Decode()
86+
if err != nil {
87+
return nil, err
88+
}
89+
emb := &Embedding{
90+
Object: d.Object,
91+
Index: d.Index,
92+
Embedding: floats,
93+
}
94+
embs = append(embs, emb)
95+
}
96+
return embs, nil
97+
case *EmbeddingResponse[[]float64]:
98+
embs := make([]*Embedding, 0, len(e.Data))
99+
for _, d := range e.Data {
100+
emb := &Embedding{
101+
Object: d.Object,
102+
Index: d.Index,
103+
Embedding: d.Embedding,
104+
}
105+
embs = append(embs, emb)
106+
}
107+
return embs, nil
108+
}
109+
110+
return nil, ErrInValidData
111+
}
112+
113+
// Embeddings returns embeddings for every object in EmbeddingRequest.
114+
func (c *Client) Embeddings(ctx context.Context, embReq *EmbeddingRequest) ([]*Embedding, error) {
115+
u, err := url.Parse(c.baseURL + "/embeddings")
116+
if err != nil {
117+
return nil, err
118+
}
119+
120+
var body = &bytes.Buffer{}
121+
enc := json.NewEncoder(body)
122+
enc.SetEscapeHTML(false)
123+
if err := enc.Encode(embReq); err != nil {
124+
return nil, err
125+
}
126+
127+
req, err := c.newRequest(ctx, http.MethodPost, u.String(), body)
128+
if err != nil {
129+
return nil, err
130+
}
131+
132+
resp, err := c.doRequest(req)
133+
if err != nil {
134+
return nil, err
135+
}
136+
defer resp.Body.Close()
137+
138+
switch embReq.EncodingFormat {
139+
case EncodingBase64:
140+
return ToEmbeddings[EmbeddingResponse[EmbeddingString]](resp.Body)
141+
case EncodingFloat:
142+
return ToEmbeddings[EmbeddingResponse[[]float64]](resp.Body)
143+
}
144+
145+
return nil, fmt.Errorf("unknown encoding: %v", embReq.EncodingFormat)
146+
}

openai/error.go

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
package openai
2+
3+
import (
4+
"encoding/json"
5+
"errors"
6+
)
7+
8+
// ErrInValidData is returned when response data cannot be converted
// into embeddings because its type is not supported.
var ErrInValidData = errors.New("invalid data")
11+
12+
// APIError is the error payload decoded from a failed API response.
type APIError struct {
	// Err mirrors the "error" object of the response body.
	Err struct {
		Message string  `json:"message"`
		Type    string  `json:"type"`
		Param   *string `json:"param,omitempty"`
		Code    any     `json:"code,omitempty"`
	} `json:"error"`
}

// Error implements the error interface. It returns the error serialized
// as JSON, or "unknown error" when serialization fails.
func (e APIError) Error() string {
	if b, err := json.Marshal(e); err == nil {
		return string(b)
	}
	return "unknown error"
}

openai/openai.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
package openai
2+
3+
// Model is the name of an embedding model.
type Model string

const (
	// TextAdaV2 is the "text-embedding-ada-002" model.
	TextAdaV2 Model = "text-embedding-ada-002"
)
9+
10+
// EncodingFormat is the encoding of the embeddings requested from the
// embeddings API.
type EncodingFormat string

const (
	// EncodingFloat requests the "float" encoding.
	EncodingFloat EncodingFormat = "float"
	// EncodingBase64 requests the "base64" encoding.
	EncodingBase64 EncodingFormat = "base64"
)

0 commit comments

Comments
 (0)