Skip to content

Commit 23f68bc

Browse files
authored
cleanup: dedupe a lot of code and add some tests (#7)
Signed-off-by: Milos Gajdos <[email protected]>
1 parent 329b912 commit 23f68bc

File tree

16 files changed

+428
-187
lines changed

16 files changed

+428
-187
lines changed

cohere/client.go

Lines changed: 0 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,7 @@
11
package cohere
22

33
import (
4-
"bytes"
5-
"context"
64
"encoding/json"
7-
"fmt"
8-
"io"
95
"net/http"
106
"os"
117
)
@@ -62,58 +58,6 @@ func (c *Client) WithHTTPClient(httpClient *http.Client) *Client {
6258
return c
6359
}
6460

65-
// ReqOption is http requestion functional option.
66-
type ReqOption func(*http.Request)
67-
68-
// WithSetHeader sets the header key to value val.
69-
func WithSetHeader(key, val string) ReqOption {
70-
return func(req *http.Request) {
71-
if req.Header == nil {
72-
req.Header = make(http.Header)
73-
}
74-
req.Header.Set(key, val)
75-
}
76-
}
77-
78-
// WithAddHeader adds the val to key header.
79-
func WithAddHeader(key, val string) ReqOption {
80-
return func(req *http.Request) {
81-
if req.Header == nil {
82-
req.Header = make(http.Header)
83-
}
84-
req.Header.Add(key, val)
85-
}
86-
}
87-
88-
func (c *Client) newRequest(ctx context.Context, method, url string, body io.Reader, opts ...ReqOption) (*http.Request, error) {
89-
if ctx == nil {
90-
ctx = context.Background()
91-
}
92-
if body == nil {
93-
body = &bytes.Reader{}
94-
}
95-
96-
req, err := http.NewRequestWithContext(ctx, method, url, body)
97-
if err != nil {
98-
return nil, err
99-
}
100-
101-
for _, setOption := range opts {
102-
setOption(req)
103-
}
104-
105-
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.apiKey))
106-
req.Header.Set("Accept", "application/json; charset=utf-8")
107-
if body != nil {
108-
// if no content-type is specified we default to json
109-
if ct := req.Header.Get("Content-Type"); len(ct) == 0 {
110-
req.Header.Set("Content-Type", "application/json; charset=utf-8")
111-
}
112-
}
113-
114-
return req, nil
115-
}
116-
11761
func (c *Client) doRequest(req *http.Request) (*http.Response, error) {
11862
resp, err := c.hc.Do(req)
11963
if err != nil {

cohere/client_test.go

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
package cohere
2+
3+
import (
4+
"net/http"
5+
"testing"
6+
7+
"github.com/stretchr/testify/assert"
8+
)
9+
10+
const (
11+
cohereAPIKey = "somekey"
12+
)
13+
14+
func TestClient(t *testing.T) {
15+
t.Setenv("COHERE_API_KEY", cohereAPIKey)
16+
17+
t.Run("API key", func(t *testing.T) {
18+
c := NewClient()
19+
assert.Equal(t, c.apiKey, cohereAPIKey)
20+
21+
testVal := "foo"
22+
c.WithAPIKey(testVal)
23+
assert.Equal(t, c.apiKey, testVal)
24+
})
25+
26+
t.Run("BaseURL", func(t *testing.T) {
27+
c := NewClient()
28+
assert.Equal(t, c.baseURL, BaseURL)
29+
30+
testVal := "http://foo"
31+
c.WithBaseURL(testVal)
32+
assert.Equal(t, c.baseURL, testVal)
33+
})
34+
35+
t.Run("version", func(t *testing.T) {
36+
c := NewClient()
37+
assert.Equal(t, c.version, EmbedAPIVersion)
38+
39+
testVal := "v3"
40+
c.WithVersion(testVal)
41+
assert.Equal(t, c.version, testVal)
42+
})
43+
44+
t.Run("http client", func(t *testing.T) {
45+
c := NewClient()
46+
assert.NotNil(t, c.hc)
47+
48+
testVal := &http.Client{}
49+
c.WithHTTPClient(testVal)
50+
assert.NotNil(t, c.hc)
51+
})
52+
}

cohere/embedding.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"net/url"
99

1010
"github.com/milosgajdos/go-embeddings"
11+
"github.com/milosgajdos/go-embeddings/request"
1112
)
1213

1314
// EmbeddingRequest sent to API endpoint.
@@ -63,7 +64,11 @@ func (c *Client) Embeddings(ctx context.Context, embReq *EmbeddingRequest) (*Emb
6364
return nil, err
6465
}
6566

66-
req, err := c.newRequest(ctx, http.MethodPost, u.String(), body)
67+
options := []request.Option{
68+
request.WithBearer(c.apiKey),
69+
}
70+
71+
req, err := request.NewHTTP(ctx, http.MethodPost, u.String(), body, options...)
6772
if err != nil {
6873
return nil, err
6974
}

embedding.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@ type Embedding struct {
77

88
// ToFloat32 returns Embedding verctor as a slice of float32.
99
func (e Embedding) ToFloat32() []float32 {
10-
floats := make([]float32, 0, len(e.Vector))
11-
for _, f := range e.Vector {
12-
floats = append(floats, float32(f))
10+
floats := make([]float32, len(e.Vector))
11+
for i, f := range e.Vector {
12+
floats[i] = float32(f)
1313
}
1414
return floats
1515
}

embedding_test.go

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
package embeddings
2+
3+
import "testing"
4+
5+
func TestToFloat32(t *testing.T) {
6+
e := Embedding{
7+
Vector: []float64{1.0, 2.0, 3.0},
8+
}
9+
10+
exp := []float32{1.0, 2.0, 3.0}
11+
got := e.ToFloat32()
12+
13+
if len(got) != len(exp) {
14+
t.Fatalf("expected %d vals, got %v", len(exp), len(got))
15+
}
16+
17+
for i, f := range got {
18+
if exp[i] != f {
19+
t.Fatalf("expected %v, got %v", exp, got)
20+
}
21+
}
22+
}

go.mod

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,19 @@ module github.com/milosgajdos/go-embeddings
22

33
go 1.20
44

5-
require golang.org/x/oauth2 v0.15.0
5+
require (
6+
github.com/stretchr/testify v1.8.4
7+
golang.org/x/oauth2 v0.15.0
8+
)
69

710
require (
811
cloud.google.com/go/compute v1.20.1 // indirect
912
cloud.google.com/go/compute/metadata v0.2.3 // indirect
13+
github.com/davecgh/go-spew v1.1.1 // indirect
1014
github.com/golang/protobuf v1.5.3 // indirect
15+
github.com/pmezard/go-difflib v1.0.0 // indirect
1116
golang.org/x/net v0.19.0 // indirect
1217
google.golang.org/appengine v1.6.7 // indirect
1318
google.golang.org/protobuf v1.31.0 // indirect
19+
gopkg.in/yaml.v3 v3.0.1 // indirect
1420
)

go.sum

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,18 @@ cloud.google.com/go/compute v1.20.1 h1:6aKEtlUiwEpJzM001l0yFkpXmUVXaN8W+fbkb2AZN
22
cloud.google.com/go/compute v1.20.1/go.mod h1:4tCnrn48xsqlwSAiLf1HXMQk8CONslYbdiEZc9FEIbM=
33
cloud.google.com/go/compute/metadata v0.2.3 h1:mg4jlk7mCAj6xXp9UJ4fjI9VUI5rubuGBW5aJ7UnBMY=
44
cloud.google.com/go/compute/metadata v0.2.3/go.mod h1:VAV5nSsACxMJvgaAuX6Pk2AawlZn8kiOGuCv6gTkwuA=
5+
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
6+
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
57
github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
68
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
79
github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg=
810
github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
911
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
1012
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
13+
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
14+
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
15+
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
16+
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
1117
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
1218
golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks=
1319
golang.org/x/net v0.19.0 h1:zTwKpTd2XuCqf8huc7Fo2iSy+4RHPd10s4KzeTnVr1c=
@@ -25,3 +31,7 @@ google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp0
2531
google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
2632
google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8=
2733
google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
34+
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
35+
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
36+
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
37+
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

gomod2nix.toml

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,18 @@ schema = 3
77
[mod."cloud.google.com/go/compute/metadata"]
88
version = "v0.2.3"
99
hash = "sha256-kYB1FTQRdTDqCqJzSU/jJYbVUGyxbkASUKbEs36FUyU="
10+
[mod."github.com/davecgh/go-spew"]
11+
version = "v1.1.1"
12+
hash = "sha256-nhzSUrE1fCkN0+RL04N4h8jWmRFPPPWbCuDc7Ss0akI="
1013
[mod."github.com/golang/protobuf"]
1114
version = "v1.5.3"
1215
hash = "sha256-svogITcP4orUIsJFjMtp+Uv1+fKJv2Q5Zwf2dMqnpOQ="
16+
[mod."github.com/pmezard/go-difflib"]
17+
version = "v1.0.0"
18+
hash = "sha256-/FtmHnaGjdvEIKAJtrUfEhV7EVo5A/eYrtdnUkuxLDA="
19+
[mod."github.com/stretchr/testify"]
20+
version = "v1.8.4"
21+
hash = "sha256-MoOmRzbz9QgiJ+OOBo5h5/LbilhJfRUryvzHJmXAWjo="
1322
[mod."golang.org/x/net"]
1423
version = "v0.19.0"
1524
hash = "sha256-3M5rKEvJx4cO/q+06cGjR5sxF5JpnUWY0+fQttrWdT4="
@@ -22,3 +31,6 @@ schema = 3
2231
[mod."google.golang.org/protobuf"]
2332
version = "v1.31.0"
2433
hash = "sha256-UdIk+xRaMfdhVICvKRk1THe3R1VU+lWD8hqoW/y8jT0="
34+
[mod."gopkg.in/yaml.v3"]
35+
version = "v3.0.1"
36+
hash = "sha256-FqL9TKYJ0XkNwJFnq9j0VvJ5ZUU1RvH/52h/f5bkYAU="

openai/client.go

Lines changed: 0 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,7 @@
11
package openai
22

33
import (
4-
"bytes"
5-
"context"
64
"encoding/json"
7-
"fmt"
8-
"io"
95
"net/http"
106
"os"
117
)
@@ -72,62 +68,6 @@ func (c *Client) WithHTTPClient(httpClient *http.Client) *Client {
7268
return c
7369
}
7470

75-
// ReqOption is http requestion functional option.
76-
type ReqOption func(*http.Request)
77-
78-
// WithSetHeader sets the header key to value val.
79-
func WithSetHeader(key, val string) ReqOption {
80-
return func(req *http.Request) {
81-
if req.Header == nil {
82-
req.Header = make(http.Header)
83-
}
84-
req.Header.Set(key, val)
85-
}
86-
}
87-
88-
// WithAddHeader adds the val to key header.
89-
func WithAddHeader(key, val string) ReqOption {
90-
return func(req *http.Request) {
91-
if req.Header == nil {
92-
req.Header = make(http.Header)
93-
}
94-
req.Header.Add(key, val)
95-
}
96-
}
97-
98-
func (c *Client) newRequest(ctx context.Context, method, uri string, body io.Reader, opts ...ReqOption) (*http.Request, error) {
99-
if ctx == nil {
100-
ctx = context.Background()
101-
}
102-
if body == nil {
103-
body = &bytes.Reader{}
104-
}
105-
106-
req, err := http.NewRequestWithContext(ctx, method, uri, body)
107-
if err != nil {
108-
return nil, err
109-
}
110-
111-
for _, setOption := range opts {
112-
setOption(req)
113-
}
114-
115-
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.apiKey))
116-
if len(c.orgID) != 0 {
117-
req.Header.Set("OpenAI-Organization", c.orgID)
118-
}
119-
120-
req.Header.Set("Accept", "application/json; charset=utf-8")
121-
if body != nil {
122-
// if no content-type is specified we default to json
123-
if ct := req.Header.Get("Content-Type"); len(ct) == 0 {
124-
req.Header.Set("Content-Type", "application/json; charset=utf-8")
125-
}
126-
}
127-
128-
return req, nil
129-
}
130-
13171
func (c *Client) doRequest(req *http.Request) (*http.Response, error) {
13272
resp, err := c.hc.Do(req)
13373
if err != nil {

openai/client_test.go

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
package openai
2+
3+
import (
4+
"net/http"
5+
"testing"
6+
7+
"github.com/stretchr/testify/assert"
8+
)
9+
10+
const (
11+
openaiKey = "t0ps3cr3tk3y"
12+
)
13+
14+
func TestClient(t *testing.T) {
15+
t.Setenv("OPENAI_API_KEY", openaiKey)
16+
17+
t.Run("API key", func(t *testing.T) {
18+
c := NewClient()
19+
assert.Equal(t, c.apiKey, openaiKey)
20+
21+
testVal := "foo"
22+
c.WithAPIKey(testVal)
23+
assert.Equal(t, c.apiKey, testVal)
24+
})
25+
26+
t.Run("BaseURL", func(t *testing.T) {
27+
c := NewClient()
28+
assert.Equal(t, c.baseURL, BaseURL)
29+
30+
testVal := "http://foo"
31+
c.WithBaseURL(testVal)
32+
assert.Equal(t, c.baseURL, testVal)
33+
})
34+
35+
t.Run("version", func(t *testing.T) {
36+
c := NewClient()
37+
assert.Equal(t, c.version, EmbedAPIVersion)
38+
39+
testVal := "v3"
40+
c.WithVersion(testVal)
41+
assert.Equal(t, c.version, testVal)
42+
})
43+
44+
t.Run("orgID", func(t *testing.T) {
45+
c := NewClient()
46+
assert.Equal(t, c.orgID, "")
47+
48+
testVal := "orgID"
49+
c.WithOrgID(testVal)
50+
assert.Equal(t, c.orgID, testVal)
51+
})
52+
53+
t.Run("http client", func(t *testing.T) {
54+
c := NewClient()
55+
assert.NotNil(t, c.hc)
56+
57+
testVal := &http.Client{}
58+
c.WithHTTPClient(testVal)
59+
assert.NotNil(t, c.hc)
60+
})
61+
}

0 commit comments

Comments
 (0)