Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions components/embedding/openai/embedding.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,10 @@ type EmbeddingConfig struct {
// User is a unique identifier representing your end-user
// Optional. Helps OpenAI monitor and detect abuse
User *string `json:"user,omitempty"`

// BatchSize specifies the maximum number of texts to embed in a single request.
// Optional. When zero, all texts are embedded in one request.
BatchSize int `json:"batch_size,omitempty"`
}

var _ embedding.Embedder = (*Embedder)(nil)
Expand Down Expand Up @@ -114,6 +118,7 @@ func NewEmbedder(ctx context.Context, config *EmbeddingConfig) (*Embedder, error
EncodingFormat: config.EncodingFormat,
Dimensions: config.Dimensions,
User: config.User,
BatchSize: config.BatchSize,
}
}
cli, err := openai.NewEmbeddingClient(ctx, nConf)
Expand Down
59 changes: 41 additions & 18 deletions libs/acl/openai/embedding.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,9 @@ import (
"context"
"net/http"

"github.com/meguminnnnnnnnn/go-openai"

"github.com/cloudwego/eino/callbacks"
"github.com/cloudwego/eino/components/embedding"
"github.com/meguminnnnnnnnn/go-openai"
)

type EmbeddingEncodingFormat string
Expand Down Expand Up @@ -77,6 +76,10 @@ type EmbeddingConfig struct {
// User is a unique identifier representing your end-user
// Optional. Helps OpenAI monitor and detect abuse
User *string `json:"user,omitempty"`

// BatchSize specifies the maximum number of texts to embed in a single request.
// Optional. When zero, all texts are embedded in one request.
BatchSize int `json:"batch_size,omitempty"`
}

var _ embedding.Embedder = (*EmbeddingClient)(nil)
Expand Down Expand Up @@ -131,7 +134,6 @@ func (e *EmbeddingClient) EmbedStrings(ctx context.Context, texts []string, opts
options = embedding.GetCommonOptions(options, opts...)

req := &openai.EmbeddingRequest{
Input: texts,
Model: openai.EmbeddingModel(*options.Model),
User: dereferenceOrZero(e.config.User),
EncodingFormat: openai.EmbeddingEncodingFormat(dereferenceOrDefault(e.config.EncodingFormat, EmbeddingEncodingFormatFloat)),
Expand All @@ -143,29 +145,50 @@ func (e *EmbeddingClient) EmbedStrings(ctx context.Context, texts []string, opts
EncodingFormat: string(req.EncodingFormat),
}

embeddings = make([][]float64, 0, len(texts))
usage := &embedding.TokenUsage{
PromptTokens: 0,
CompletionTokens: 0,
TotalTokens: 0,
}

var batchSize int
if e.config.BatchSize == 0 {
batchSize = len(texts)
} else {
batchSize = e.config.BatchSize
}

ctx = callbacks.OnStart(ctx, &embedding.CallbackInput{
Texts: texts,
Config: conf,
})

resp, err := e.cli.CreateEmbeddings(ctx, *req)
if err != nil {
return nil, err
}
for i := 0; i < len(texts); i += batchSize {
idx := i
var end int
if idx+batchSize > len(texts) {
end = len(texts)
} else {
end = idx + batchSize
}
req.Input = texts[idx:end]
resp, err2 := e.cli.CreateEmbeddings(ctx, *req)
if err2 != nil {
return nil, err2
}

embeddings = make([][]float64, len(resp.Data))
for i, d := range resp.Data {
res := make([]float64, len(d.Embedding))
for j, emb := range d.Embedding {
res[j] = float64(emb)
for _, d := range resp.Data {
res := make([]float64, len(d.Embedding))
for k, emb := range d.Embedding {
res[k] = float64(emb)
}
embeddings = append(embeddings, res)
}
embeddings[i] = res
}

usage := &embedding.TokenUsage{
PromptTokens: resp.Usage.PromptTokens,
CompletionTokens: resp.Usage.CompletionTokens,
TotalTokens: resp.Usage.TotalTokens,
usage.PromptTokens += resp.Usage.PromptTokens
usage.CompletionTokens += resp.Usage.CompletionTokens
usage.TotalTokens += resp.Usage.TotalTokens
}

_ = callbacks.OnEnd(ctx, &embedding.CallbackOutput{
Expand Down
1 change: 1 addition & 0 deletions libs/acl/openai/embedding_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ func TestEmbedStrings(t *testing.T) {
APIKey: "{your-api-key}",
APIVersion: "2024-06-01",
Model: "gpt-4o-2024-05-13",
BatchSize: 100,
})

assert.NoError(t, err)
Expand Down