Skip to content

Commit b53d156

Browse files
authored
Add version support and change Embedding.Embedding to Vector (#3)
Signed-off-by: Milos Gajdos <[email protected]>
1 parent 8448691 commit b53d156

File tree

2 files changed

+22
-12
lines changed

2 files changed

+22
-12
lines changed

openai/client.go

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@ import (
1212

1313
const (
1414
// BaseURL is OpenAI HTTP API base URL.
15-
BaseURL = "https://api.openai.com/v1"
15+
BaseURL = "https://api.openai.com"
16+
// EmbedAPIVersion is the latest stable embedding API version.
17+
EmbedAPIVersion = "v1"
1618
// Org header
1719
OrgHeader = "OpenAI-Organization"
1820
)
@@ -21,6 +23,7 @@ const (
2123
type Client struct {
2224
apiKey string
2325
baseURL string
26+
version string
2427
orgID string
2528
hc *http.Client
2629
}
@@ -34,6 +37,7 @@ func NewClient() (*Client, error) {
3437
return &Client{
3538
apiKey: os.Getenv("OPENAI_API_KEY"),
3639
baseURL: BaseURL,
40+
version: EmbedAPIVersion,
3741
orgID: "",
3842
hc: &http.Client{},
3943
}, nil
@@ -51,6 +55,12 @@ func (c *Client) WithBaseURL(baseURL string) *Client {
5155
return c
5256
}
5357

58+
// WithVersion sets the API version.
59+
func (c *Client) WithVersion(version string) *Client {
60+
c.version = version
61+
return c
62+
}
63+
5464
// WithOrgID sets the organization ID.
5565
func (c *Client) WithOrgID(orgID string) *Client {
5666
c.orgID = orgID

openai/embedding.go

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,11 @@ type Usage struct {
1919
TotalTokens int `json:"total_tokens"`
2020
}
2121

22-
// Embedding is openai API vector embedding.
22+
// Embedding is openai API embedding.
2323
type Embedding struct {
24-
Object string `json:"object"`
25-
Index int `json:"index"`
26-
Embedding []float64 `json:"embedding"`
24+
Object string `json:"object"`
25+
Index int `json:"index"`
26+
Vector []float64 `json:"vector"`
2727
}
2828

2929
// EmbeddingString is base64 encoded embedding.
@@ -87,9 +87,9 @@ func ToEmbeddings[T any](resp io.Reader) ([]*Embedding, error) {
8787
return nil, err
8888
}
8989
emb := &Embedding{
90-
Object: d.Object,
91-
Index: d.Index,
92-
Embedding: floats,
90+
Object: d.Object,
91+
Index: d.Index,
92+
Vector: floats,
9393
}
9494
embs = append(embs, emb)
9595
}
@@ -98,9 +98,9 @@ func ToEmbeddings[T any](resp io.Reader) ([]*Embedding, error) {
9898
embs := make([]*Embedding, 0, len(e.Data))
9999
for _, d := range e.Data {
100100
emb := &Embedding{
101-
Object: d.Object,
102-
Index: d.Index,
103-
Embedding: d.Embedding,
101+
Object: d.Object,
102+
Index: d.Index,
103+
Vector: d.Embedding,
104104
}
105105
embs = append(embs, emb)
106106
}
@@ -112,7 +112,7 @@ func ToEmbeddings[T any](resp io.Reader) ([]*Embedding, error) {
112112

113113
// Embeddings returns embeddings for every object in EmbeddingRequest.
114114
func (c *Client) Embeddings(ctx context.Context, embReq *EmbeddingRequest) ([]*Embedding, error) {
115-
u, err := url.Parse(c.baseURL + "/embeddings")
115+
u, err := url.Parse(c.baseURL + "/" + c.version + "/embeddings")
116116
if err != nil {
117117
return nil, err
118118
}

0 commit comments

Comments
 (0)