Skip to content

Commit 32157c3

Browse files
committed
Add hugging face API but it's a bit of a mess
Signed-off-by: Milos Gajdos <[email protected]>
1 parent 23f68bc commit 32157c3

File tree

5 files changed

+273
-0
lines changed

5 files changed

+273
-0
lines changed

cmd/huggingface/main.go

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
package main
2+
3+
import (
4+
"context"
5+
"flag"
6+
"fmt"
7+
"log"
8+
9+
"github.com/milosgajdos/go-embeddings/cohere"
10+
"github.com/milosgajdos/go-embeddings/huggingface"
11+
)
12+
13+
var (
14+
input string
15+
model string
16+
wait bool
17+
)
18+
19+
func init() {
20+
flag.StringVar(&input, "input", "what is life", "input data")
21+
flag.StringVar(&model, "model", string(cohere.EnglishV3), "model name")
22+
flag.BoolVar(&wait, "wait", false, "wait for model to start")
23+
}
24+
25+
func main() {
26+
flag.Parse()
27+
28+
c := huggingface.NewClient().
29+
WithModel(model)
30+
31+
embReq := &huggingface.EmbeddingRequest{
32+
Inputs: []string{input},
33+
Options: huggingface.Options{
34+
WaitForModel: &wait,
35+
},
36+
}
37+
38+
embResp, err := c.Embeddings(context.Background(), embReq)
39+
if err != nil {
40+
log.Fatal(err)
41+
}
42+
43+
embs, err := huggingface.ToEmbeddings(embResp)
44+
if err != nil {
45+
log.Fatal(err)
46+
}
47+
48+
fmt.Printf("got %d embeddings", len(embs))
49+
}

huggingface/client.go

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
package huggingface
2+
3+
import (
4+
"encoding/json"
5+
"net/http"
6+
"os"
7+
)
8+
9+
const (
10+
// BaseURL is Cohere HTTP API base URL.
11+
BaseURL = "https://api-inference.huggingface.co/models"
12+
)
13+
14+
// Client is Cohere HTTP API client.
15+
type Client struct {
16+
apiKey string
17+
baseURL string
18+
model string
19+
hc *http.Client
20+
}
21+
22+
// NewClient creates a new HTTP API client and returns it.
23+
// By default it reads the Cohere API key from HUGGINGFACE_API_KEY
24+
// env var and uses the default Go http.Client for making API requests.
25+
// You can override the default options via the client methods.
26+
func NewClient() *Client {
27+
return &Client{
28+
apiKey: os.Getenv("HUGGINGFACE_API_KEY"),
29+
baseURL: BaseURL,
30+
hc: &http.Client{},
31+
}
32+
}
33+
34+
// WithAPIKey sets the API key.
35+
func (c *Client) WithAPIKey(apiKey string) *Client {
36+
c.apiKey = apiKey
37+
return c
38+
}
39+
40+
// WithBaseURL sets the API base URL.
41+
func (c *Client) WithBaseURL(baseURL string) *Client {
42+
c.baseURL = baseURL
43+
return c
44+
}
45+
46+
// WithModel sets the model name
47+
func (c *Client) WithModel(model string) *Client {
48+
c.model = model
49+
return c
50+
}
51+
52+
// WithHTTPClient sets the HTTP client.
53+
func (c *Client) WithHTTPClient(httpClient *http.Client) *Client {
54+
c.hc = httpClient
55+
return c
56+
}
57+
58+
func (c *Client) doRequest(req *http.Request) (*http.Response, error) {
59+
resp, err := c.hc.Do(req)
60+
if err != nil {
61+
return nil, err
62+
}
63+
if resp.StatusCode >= http.StatusOK && resp.StatusCode < http.StatusBadRequest {
64+
return resp, nil
65+
}
66+
defer resp.Body.Close()
67+
68+
var apiErr APIError
69+
if err := json.NewDecoder(resp.Body).Decode(&apiErr); err != nil {
70+
return nil, err
71+
}
72+
73+
return nil, apiErr
74+
}

huggingface/client_test.go

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
package huggingface
2+
3+
import (
4+
"net/http"
5+
"testing"
6+
7+
"github.com/stretchr/testify/assert"
8+
)
9+
10+
const (
11+
huggingFaceKey = "somekey"
12+
)
13+
14+
func TestClient(t *testing.T) {
15+
t.Setenv("HUGGINGFACE_API_KEY", huggingFaceKey)
16+
17+
t.Run("API key", func(t *testing.T) {
18+
c := NewClient()
19+
assert.Equal(t, c.apiKey, huggingFaceKey)
20+
21+
testVal := "foo"
22+
c.WithAPIKey(testVal)
23+
assert.Equal(t, c.apiKey, testVal)
24+
})
25+
26+
t.Run("BaseURL", func(t *testing.T) {
27+
c := NewClient()
28+
assert.Equal(t, c.baseURL, BaseURL)
29+
30+
testVal := "http://foo"
31+
c.WithBaseURL(testVal)
32+
assert.Equal(t, c.baseURL, testVal)
33+
})
34+
35+
t.Run("Model", func(t *testing.T) {
36+
c := NewClient()
37+
assert.Equal(t, c.model, "")
38+
39+
testVal := "foo/bar"
40+
c.WithModel(testVal)
41+
assert.Equal(t, c.model, testVal)
42+
})
43+
44+
t.Run("http client", func(t *testing.T) {
45+
c := NewClient()
46+
assert.NotNil(t, c.hc)
47+
48+
testVal := &http.Client{}
49+
c.WithHTTPClient(testVal)
50+
assert.NotNil(t, c.hc)
51+
})
52+
}

huggingface/embedding.go

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
package huggingface
2+
3+
import (
4+
"bytes"
5+
"context"
6+
"encoding/json"
7+
"net/http"
8+
"net/url"
9+
10+
"github.com/milosgajdos/go-embeddings"
11+
"github.com/milosgajdos/go-embeddings/request"
12+
)
13+
14+
// EmbeddingRequest sent to API endpoint.
15+
type EmbeddingRequest struct {
16+
Inputs []string `json:"inputs"`
17+
Options Options `json:"options,omitempty"`
18+
}
19+
20+
// Options
21+
type Options struct {
22+
WaitForModel *bool `json:"wait_for_model,omitempty"`
23+
}
24+
25+
// EmbedddingResponse is returned by API.
26+
// TODO: hugging face APIs are a mess
27+
type EmbedddingResponse [][][][]float64
28+
29+
// ToEmbeddings converts the raw API response,
30+
// parses it into a slice of embeddings and returns it.
31+
func ToEmbeddings(e *EmbedddingResponse) ([]*embeddings.Embedding, error) {
32+
emb := *e
33+
embs := make([]*embeddings.Embedding, 0, len(emb))
34+
// for i := range emb {
35+
// vals := emb[i]
36+
// floats := make([]float64, len(vals))
37+
// copy(floats, vals)
38+
// emb := &embeddings.Embedding{
39+
// Vector: floats,
40+
// }
41+
// embs = append(embs, emb)
42+
// }
43+
return embs, nil
44+
}
45+
46+
// Embeddings returns embeddings for every object in EmbeddingRequest.
47+
func (c *Client) Embeddings(ctx context.Context, embReq *EmbeddingRequest) (*EmbedddingResponse, error) {
48+
u, err := url.Parse(c.baseURL + "/" + c.model)
49+
if err != nil {
50+
return nil, err
51+
}
52+
53+
var body = &bytes.Buffer{}
54+
enc := json.NewEncoder(body)
55+
enc.SetEscapeHTML(false)
56+
if err := enc.Encode(embReq); err != nil {
57+
return nil, err
58+
}
59+
60+
options := []request.Option{
61+
request.WithBearer(c.apiKey),
62+
}
63+
64+
req, err := request.NewHTTP(ctx, http.MethodPost, u.String(), body, options...)
65+
if err != nil {
66+
return nil, err
67+
}
68+
69+
resp, err := c.doRequest(req)
70+
if err != nil {
71+
return nil, err
72+
}
73+
defer resp.Body.Close()
74+
75+
e := new(EmbedddingResponse)
76+
if err := json.NewDecoder(resp.Body).Decode(e); err != nil {
77+
return nil, err
78+
}
79+
80+
return e, nil
81+
}

huggingface/error.go

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
package huggingface
2+
3+
import "encoding/json"
4+
5+
// APIError is error returned by API
6+
type APIError struct {
7+
Message string `json:"error"`
8+
}
9+
10+
// Error implements errors interface.
11+
func (e APIError) Error() string {
12+
b, err := json.Marshal(e)
13+
if err != nil {
14+
return "unknown error"
15+
}
16+
return string(b)
17+
}

0 commit comments

Comments
 (0)