Skip to content

Commit a97d91c

Browse files
fix: [2.6] add dim parameter support for Siliconflow & cohere provider (#47081)
#47077
pr: #47080
Signed-off-by: junjie.jiang <junjie.jiang@zilliz.com>
1 parent bfc6db3 commit a97d91c

File tree

6 files changed

+62
-35
lines changed

6 files changed

+62
-35
lines changed

internal/util/function/embedding/cohere_embedding_provider.go

Lines changed: 25 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -33,12 +33,13 @@ import (
3333
type CohereEmbeddingProvider struct {
3434
fieldDim int64
3535

36-
client *cohere.CohereClient
37-
url string
38-
modelName string
39-
truncate string
40-
embdType models.EmbeddingType
41-
outputType string
36+
client *cohere.CohereClient
37+
url string
38+
modelName string
39+
truncate string
40+
embedDimParam int64
41+
embdType models.EmbeddingType
42+
outputType string
4243

4344
maxBatch int
4445
timeoutSec int64
@@ -55,11 +56,17 @@ func NewCohereEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchem
5556
return nil, err
5657
}
5758
var modelName string
59+
var dim int64
5860
truncate := "END"
5961
for _, param := range functionSchema.Params {
6062
switch strings.ToLower(param.Key) {
6163
case models.ModelNameParamKey:
6264
modelName = param.Value
65+
case models.DimParamKey:
66+
dim, err = models.ParseAndCheckFieldDim(param.Value, fieldDim, fieldSchema.Name)
67+
if err != nil {
68+
return nil, err
69+
}
6370
case models.TruncateParamKey:
6471
if param.Value != "NONE" && param.Value != "START" && param.Value != "END" {
6572
return nil, fmt.Errorf("Illegal parameters, %s only supports [NONE, START, END]", models.TruncateParamKey)
@@ -91,16 +98,17 @@ func NewCohereEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchem
9198
}()
9299

93100
provider := CohereEmbeddingProvider{
94-
client: c,
95-
url: url,
96-
fieldDim: fieldDim,
97-
modelName: modelName,
98-
truncate: truncate,
99-
embdType: embdType,
100-
outputType: outputType,
101-
maxBatch: 96,
102-
timeoutSec: 30,
103-
extraInfo: extraInfo,
101+
client: c,
102+
url: url,
103+
fieldDim: fieldDim,
104+
modelName: modelName,
105+
truncate: truncate,
106+
embedDimParam: dim,
107+
embdType: embdType,
108+
outputType: outputType,
109+
maxBatch: 96,
110+
timeoutSec: 30,
111+
extraInfo: extraInfo,
104112
}
105113
return &provider, nil
106114
}
@@ -135,7 +143,7 @@ func (provider *CohereEmbeddingProvider) CallEmbedding(ctx context.Context, text
135143
end = numRows
136144
}
137145

138-
resp, err := provider.client.Embedding(provider.url, provider.modelName, texts[i:end], inputType, provider.outputType, provider.truncate, provider.timeoutSec)
146+
resp, err := provider.client.Embedding(provider.url, provider.modelName, texts[i:end], inputType, provider.outputType, provider.truncate, int(provider.embedDimParam), provider.timeoutSec)
139147
if err != nil {
140148
return nil, err
141149
}

internal/util/function/embedding/siliconflow_embedding_provider.go

Lines changed: 15 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -53,11 +53,17 @@ func NewSiliconflowEmbeddingProvider(fieldSchema *schemapb.FieldSchema, function
5353
return nil, err
5454
}
5555
var modelName string
56+
var dim int64
5657

5758
for _, param := range functionSchema.Params {
5859
switch strings.ToLower(param.Key) {
5960
case models.ModelNameParamKey:
6061
modelName = param.Value
62+
case models.DimParamKey:
63+
dim, err = models.ParseAndCheckFieldDim(param.Value, fieldDim, fieldSchema.Name)
64+
if err != nil {
65+
return nil, err
66+
}
6167
default:
6268
}
6369
}
@@ -72,13 +78,14 @@ func NewSiliconflowEmbeddingProvider(fieldSchema *schemapb.FieldSchema, function
7278
}
7379

7480
provider := SiliconflowEmbeddingProvider{
75-
client: c,
76-
url: url,
77-
fieldDim: fieldDim,
78-
modelName: modelName,
79-
maxBatch: 32,
80-
timeoutSec: 30,
81-
extraInfo: extraInfo,
81+
client: c,
82+
url: url,
83+
fieldDim: fieldDim,
84+
modelName: modelName,
85+
embedDimParam: dim,
86+
maxBatch: 32,
87+
timeoutSec: 30,
88+
extraInfo: extraInfo,
8289
}
8390
return &provider, nil
8491
}
@@ -99,7 +106,7 @@ func (provider *SiliconflowEmbeddingProvider) CallEmbedding(ctx context.Context,
99106
if end > numRows {
100107
end = numRows
101108
}
102-
resp, err := provider.client.Embedding(provider.url, provider.modelName, texts[i:end], "float", provider.timeoutSec)
109+
resp, err := provider.client.Embedding(provider.url, provider.modelName, texts[i:end], "float", int(provider.embedDimParam), provider.timeoutSec)
103110
if err != nil {
104111
return nil, err
105112
}

internal/util/function/models/cohere/cohere_client.go

Lines changed: 8 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -44,9 +44,9 @@ func (c *CohereClient) headers() map[string]string {
4444
}
4545
}
4646

47-
func (c *CohereClient) Embedding(url string, modelName string, texts []string, inputType string, outputType string, truncate string, timeoutSec int64) (*EmbeddingResponse, error) {
47+
func (c *CohereClient) Embedding(url string, modelName string, texts []string, inputType string, outputType string, truncate string, dim int, timeoutSec int64) (*EmbeddingResponse, error) {
4848
embClient := newCohereEmbedding(c.apiKey, url)
49-
return embClient.embedding(modelName, texts, inputType, outputType, truncate, c.headers(), timeoutSec)
49+
return embClient.embedding(modelName, texts, inputType, outputType, truncate, dim, c.headers(), timeoutSec)
5050
}
5151

5252
func (c *CohereClient) Rerank(url string, modelName string, query string, texts []string, params map[string]any, timeoutSec int64) (*RerankResponse, error) {
@@ -70,6 +70,8 @@ type EmbeddingRequest struct {
7070
// exactly the maximum input token length for the model.
7171
// If NONE is selected, when the input exceeds the maximum input token length an error will be returned.
7272
Truncate string `json:"truncate,omitempty"`
73+
74+
OutputDimension int `json:"output_dimension,omitempty"`
7375
}
7476

7577
// Currently only float32/int8 is supported
@@ -95,7 +97,7 @@ func newCohereEmbedding(apiKey string, url string) *cohereEmbedding {
9597
}
9698
}
9799

98-
func (c *cohereEmbedding) embedding(modelName string, texts []string, inputType string, outputType string, truncate string, headers map[string]string, timeoutSec int64) (*EmbeddingResponse, error) {
100+
func (c *cohereEmbedding) embedding(modelName string, texts []string, inputType string, outputType string, truncate string, dim int, headers map[string]string, timeoutSec int64) (*EmbeddingResponse, error) {
99101
var r EmbeddingRequest
100102
r.Model = modelName
101103
r.Texts = texts
@@ -104,6 +106,9 @@ func (c *cohereEmbedding) embedding(modelName string, texts []string, inputType
104106
}
105107
r.EmbeddingTypes = []string{outputType}
106108
r.Truncate = truncate
109+
if dim != 0 {
110+
r.OutputDimension = dim
111+
}
107112

108113
res, err := models.PostRequest[EmbeddingResponse](r, c.url, headers, timeoutSec)
109114
if err != nil {

internal/util/function/models/cohere/cohere_client_test.go

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -69,7 +69,7 @@ func TestEmbeddingOK(t *testing.T) {
6969

7070
{
7171
c, _ := NewCohereClient("mock_key")
72-
ret, err := c.Embedding(url, "cohere-3", []string{"sentence"}, "search_document", "float", "END", 0)
72+
ret, err := c.Embedding(url, "cohere-3", []string{"sentence"}, "search_document", "float", "END", 0, 0)
7373
assert.True(t, err == nil)
7474
assert.Equal(t, ret.Embeddings.Float[0], []float32{0.0, 0.1})
7575
assert.Equal(t, ret.Embeddings.Float[1], []float32{1.0, 1.1})
@@ -86,7 +86,7 @@ func TestEmbeddingFailed(t *testing.T) {
8686

8787
{
8888
c, _ := NewCohereClient("mock_key")
89-
_, err := c.Embedding(url, "cohere-3", []string{"sentence"}, "search_document", "float", "END", 0)
89+
_, err := c.Embedding(url, "cohere-3", []string{"sentence"}, "search_document", "float", "END", 0, 0)
9090
assert.True(t, err != nil)
9191
}
9292
}

internal/util/function/models/siliconflow/siliconflow_client.go

Lines changed: 10 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -44,9 +44,9 @@ func (c *SiliconflowClient) headers() map[string]string {
4444
}
4545
}
4646

47-
func (c *SiliconflowClient) Embedding(url string, modelName string, texts []string, encodingFormat string, timeoutSec int64) (*EmbeddingResponse, error) {
47+
func (c *SiliconflowClient) Embedding(url string, modelName string, texts []string, encodingFormat string, dim int, timeoutSec int64) (*EmbeddingResponse, error) {
4848
embClient := newSiliconflowEmbedding(c.apiKey, url)
49-
return embClient.embedding(modelName, texts, encodingFormat, c.headers(), timeoutSec)
49+
return embClient.embedding(modelName, texts, encodingFormat, dim, c.headers(), timeoutSec)
5050
}
5151

5252
func (c *SiliconflowClient) Rerank(url string, modelName string, query string, texts []string, params map[string]any, timeoutSec int64) (*RerankResponse, error) {
@@ -62,6 +62,10 @@ type EmbeddingRequest struct {
6262
Input []string `json:"input"`
6363

6464
EncodingFormat string `json:"encoding_format,omitempty"`
65+
66+
// The number of dimensions the resulting output embeddings should have.
67+
// Only supported in some models.
68+
Dimensions int `json:"dimensions,omitempty"`
6569
}
6670

6771
type Usage struct {
@@ -101,11 +105,14 @@ func newSiliconflowEmbedding(apiKey string, url string) *siliconflowEmbedding {
101105
}
102106
}
103107

104-
func (c *siliconflowEmbedding) embedding(modelName string, texts []string, encodingFormat string, headers map[string]string, timeoutSec int64) (*EmbeddingResponse, error) {
108+
func (c *siliconflowEmbedding) embedding(modelName string, texts []string, encodingFormat string, dim int, headers map[string]string, timeoutSec int64) (*EmbeddingResponse, error) {
105109
var r EmbeddingRequest
106110
r.Model = modelName
107111
r.Input = texts
108112
r.EncodingFormat = encodingFormat
113+
if dim != 0 {
114+
r.Dimensions = dim
115+
}
109116

110117
res, err := models.PostRequest[EmbeddingResponse](r, c.url, headers, timeoutSec)
111118
if err != nil {

internal/util/function/models/siliconflow/siliconflow_client_test.go

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -74,7 +74,7 @@ func TestEmbeddingOK(t *testing.T) {
7474

7575
{
7676
c, _ := NewSiliconflowClient("mock_key")
77-
ret, err := c.Embedding(url, "BAAI/bge-large-zh-v1.5", []string{"sentence"}, "float", 0)
77+
ret, err := c.Embedding(url, "BAAI/bge-large-zh-v1.5", []string{"sentence"}, "float", 0, 0)
7878
assert.True(t, err == nil)
7979
assert.Equal(t, ret.Data[0].Index, 0)
8080
assert.Equal(t, ret.Data[1].Index, 1)
@@ -95,7 +95,7 @@ func TestEmbeddingFailed(t *testing.T) {
9595

9696
{
9797
c, _ := NewSiliconflowClient("mock_key")
98-
_, err := c.Embedding(url, "BAAI/bge-large-zh-v1.5", []string{"sentence"}, "float", 0)
98+
_, err := c.Embedding(url, "BAAI/bge-large-zh-v1.5", []string{"sentence"}, "float", 0, 0)
9999
assert.True(t, err != nil)
100100
}
101101
}

0 commit comments

Comments
 (0)