Skip to content

Commit 6689e87

Browse files
authored
Add Cohere embedding API support. (#2)
* Add Cohere embedding API support. Signed-off-by: Milos Gajdos <[email protected]>
1 parent b53d156 commit 6689e87

File tree

5 files changed

+313
-0
lines changed

5 files changed

+313
-0
lines changed

cmd/cohere/main.go

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
package main
2+
3+
import (
4+
"context"
5+
"flag"
6+
"fmt"
7+
"log"
8+
9+
"github.com/milosgajdos/go-embeddings/cohere"
10+
)
11+
12+
// Command-line flag values. They are bound to their flags in init and
// filled in when main calls flag.Parse.
//
//   - input:     text to embed (-input)
//   - model:     embedding model name (-model)
//   - truncate:  truncation mode (-truncate)
//   - inputType: embedding input type (-input-type)
var input, model, truncate, inputType string
18+
19+
func init() {
20+
flag.StringVar(&input, "input", "", "input data")
21+
flag.StringVar(&model, "model", string(cohere.EnglishV3), "model name")
22+
flag.StringVar(&truncate, "truncate", string(cohere.NoneTrunc), "truncate type")
23+
flag.StringVar(&inputType, "input-type", string(cohere.ClusteringInput), "input type")
24+
}
25+
26+
func main() {
27+
flag.Parse()
28+
29+
c, err := cohere.NewClient()
30+
if err != nil {
31+
log.Fatal(err)
32+
}
33+
34+
embReq := &cohere.EmbeddingRequest{
35+
Texts: []string{input},
36+
Model: cohere.Model(model),
37+
InputType: cohere.InputType(inputType),
38+
Truncate: cohere.Truncate(truncate),
39+
}
40+
41+
embs, err := c.Embeddings(context.Background(), embReq)
42+
if err != nil {
43+
log.Fatal(err)
44+
}
45+
46+
fmt.Printf("got %d embeddings", len(embs))
47+
}

cohere/client.go

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
package cohere
2+
3+
import (
4+
"bytes"
5+
"context"
6+
"encoding/json"
7+
"fmt"
8+
"io"
9+
"net/http"
10+
"os"
11+
)
12+
13+
// BaseURL is the Cohere HTTP API base URL.
const BaseURL = "https://api.cohere.ai"

// EmbedAPIVersion is the embedding API version; it becomes the path
// segment that follows the base URL in endpoint URLs.
const EmbedAPIVersion = "v1"
19+
20+
// Client is a Cohere HTTP API client.
type Client struct {
	// apiKey is sent as the bearer token in the Authorization header.
	apiKey string
	// baseURL and version together form the prefix of endpoint URLs.
	baseURL string
	version string
	// hc is the HTTP client used to execute requests.
	hc *http.Client
}
27+
28+
// NewClient creates a new HTTP client and returns it.
29+
// It reads the Cohere API key from COHERE_API_KEY env var
30+
// and uses the default Go http.Client.
31+
// You can override the default options by using the
32+
// client methods.
33+
func NewClient() (*Client, error) {
34+
return &Client{
35+
apiKey: os.Getenv("COHERE_API_KEY"),
36+
baseURL: BaseURL,
37+
version: EmbedAPIVersion,
38+
hc: &http.Client{},
39+
}, nil
40+
}
41+
42+
// WithAPIKey sets the API key.
43+
func (c *Client) WithAPIKey(apiKey string) *Client {
44+
c.apiKey = apiKey
45+
return c
46+
}
47+
48+
// WithBaseURL sets the API base URL.
49+
func (c *Client) WithBaseURL(baseURL string) *Client {
50+
c.baseURL = baseURL
51+
return c
52+
}
53+
54+
// WithVersion sets the API version.
55+
func (c *Client) WithVersion(version string) *Client {
56+
c.version = version
57+
return c
58+
}
59+
60+
// WithHTTPClient sets the HTTP client.
61+
func (c *Client) WithHTTPClient(httpClient *http.Client) *Client {
62+
c.hc = httpClient
63+
return c
64+
}
65+
66+
// ReqOption is an HTTP request functional option: a function that
// modifies the request it is given.
type ReqOption func(*http.Request)
68+
69+
// WithSetHeader sets the header key to value val.
70+
func WithSetHeader(key, val string) ReqOption {
71+
return func(req *http.Request) {
72+
if req.Header == nil {
73+
req.Header = make(http.Header)
74+
}
75+
req.Header.Set(key, val)
76+
}
77+
}
78+
79+
// WithAddHeader adds the val to key header.
80+
func WithAddHeader(key, val string) ReqOption {
81+
return func(req *http.Request) {
82+
if req.Header == nil {
83+
req.Header = make(http.Header)
84+
}
85+
req.Header.Add(key, val)
86+
}
87+
}
88+
89+
func (c *Client) newRequest(ctx context.Context, method, url string, body io.Reader, opts ...ReqOption) (*http.Request, error) {
90+
if ctx == nil {
91+
ctx = context.Background()
92+
}
93+
if body == nil {
94+
body = &bytes.Reader{}
95+
}
96+
97+
req, err := http.NewRequestWithContext(ctx, method, url, body)
98+
if err != nil {
99+
return nil, err
100+
}
101+
102+
for _, setOption := range opts {
103+
setOption(req)
104+
}
105+
106+
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.apiKey))
107+
req.Header.Set("Accept", "application/json; charset=utf-8")
108+
if body != nil {
109+
// if no content-type is specified we default to json
110+
if ct := req.Header.Get("Content-Type"); len(ct) == 0 {
111+
req.Header.Set("Content-Type", "application/json; charset=utf-8")
112+
}
113+
}
114+
115+
return req, nil
116+
}
117+
118+
func (c *Client) doRequest(req *http.Request) (*http.Response, error) {
119+
resp, err := c.hc.Do(req)
120+
if err != nil {
121+
return nil, err
122+
}
123+
if resp.StatusCode >= http.StatusOK && resp.StatusCode < http.StatusBadRequest {
124+
return resp, nil
125+
}
126+
defer resp.Body.Close()
127+
128+
var apiErr APIError
129+
if err := json.NewDecoder(resp.Body).Decode(&apiErr); err != nil {
130+
return nil, err
131+
}
132+
133+
return nil, apiErr
134+
}

cohere/cohere.go

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
package cohere
2+
3+
// Model is the name of an embedding model.
type Model string

// Known embedding model names.
const (
	EnglishV3        = Model("embed-english-v3.0")
	MultiLingV3      = Model("embed-multilingual-v3.0")
	EnglishLightV3   = Model("embed-english-light-v3.0")
	MultiLingLightV3 = Model("embed-multilingual-light-v3.0")
	EnglishV2        = Model("embed-english-v2.0")
	EnglishLightV2   = Model("embed-english-light-v2.0")
	MultiLingV2      = Model("embed-multilingual-v2.0")
)
15+
16+
// InputType is the embedding input type sent with an embedding request.
type InputType string

// Known embedding input types.
const (
	SearchDocInput      = InputType("search_document")
	SearchQueryInput    = InputType("search_query")
	ClassificationInput = InputType("classification")
	ClusteringInput     = InputType("clustering")
)
25+
26+
// Truncate controls input truncating.
type Truncate string

// Known truncation modes.
const (
	StartTrunc = Truncate("START")
	EndTrunc   = Truncate("END")
	NoneTrunc  = Truncate("NONE")
)

cohere/embedding.go

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
package cohere
2+
3+
import (
4+
"bytes"
5+
"context"
6+
"encoding/json"
7+
"io"
8+
"net/http"
9+
"net/url"
10+
)
11+
12+
// Embedding is a single vector embedding produced from a Cohere API response.
type Embedding struct {
	// Vector holds the embedding values.
	Vector []float64 `json:"vector"`
}
16+
17+
// EmbeddingRequest sent to API endpoint.
18+
type EmbeddingRequest struct {
19+
Texts []string `json:"texts"`
20+
Model Model `json:"model,omitempty"`
21+
InputType InputType `json:"input_type"`
22+
Truncate Truncate `json:"truncate,omitempty"`
23+
}
24+
25+
// EmbedddingResponse received from API endpoint.
26+
type EmbedddingResponse struct {
27+
Embeddings [][]float64 `json:"embeddings"`
28+
Meta *Meta `json:"meta,omitempty"`
29+
}
30+
31+
// Meta stores API response metadata
32+
type Meta struct {
33+
APIVersion *APIVersion `json:"api_version,omitempty"`
34+
}
35+
36+
// APIVersion stores the API version found in the response metadata.
type APIVersion struct {
	// Version is the value of the "version" JSON field.
	Version string `json:"version"`
}
40+
41+
func ToEmbeddings(r io.Reader) ([]*Embedding, error) {
42+
var resp EmbedddingResponse
43+
if err := json.NewDecoder(r).Decode(&resp); err != nil {
44+
return nil, err
45+
}
46+
embs := make([]*Embedding, 0, len(resp.Embeddings))
47+
for _, e := range resp.Embeddings {
48+
floats := make([]float64, len(e))
49+
copy(floats, e)
50+
emb := &Embedding{
51+
Vector: floats,
52+
}
53+
embs = append(embs, emb)
54+
}
55+
return embs, nil
56+
}
57+
58+
// Embeddings returns embeddings for every object in EmbeddingRequest.
59+
func (c *Client) Embeddings(ctx context.Context, embReq *EmbeddingRequest) ([]*Embedding, error) {
60+
u, err := url.Parse(c.baseURL + "/" + c.version + "/embed")
61+
if err != nil {
62+
return nil, err
63+
}
64+
65+
var body = &bytes.Buffer{}
66+
enc := json.NewEncoder(body)
67+
enc.SetEscapeHTML(false)
68+
if err := enc.Encode(embReq); err != nil {
69+
return nil, err
70+
}
71+
72+
req, err := c.newRequest(ctx, http.MethodPost, u.String(), body)
73+
if err != nil {
74+
return nil, err
75+
}
76+
77+
resp, err := c.doRequest(req)
78+
if err != nil {
79+
return nil, err
80+
}
81+
defer resp.Body.Close()
82+
83+
return ToEmbeddings(resp.Body)
84+
}

cohere/error.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
package cohere
2+
3+
import "encoding/json"
4+
5+
// APIError is the error payload decoded from a failed API response.
type APIError struct {
	// Message is the value of the "message" JSON field.
	Message string `json:"message"`
}

// Error implements the error interface. It returns the error serialized as
// JSON, or "unknown error" if serialization fails.
func (e APIError) Error() string {
	if b, err := json.Marshal(e); err == nil {
		return string(b)
	}
	return "unknown error"
}

0 commit comments

Comments
 (0)