Skip to content

Commit 788ddef

Browse files
support ark rerank
support ark rerank support ark rerank
1 parent 19c63a1 commit 788ddef

File tree

5 files changed

+106
-9
lines changed

5 files changed

+106
-9
lines changed

backend/application/base/appinfra/app_infra.go

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ import (
5959
"github.com/coze-dev/coze-studio/backend/infra/impl/document/parser/builtin"
6060
"github.com/coze-dev/coze-studio/backend/infra/impl/document/parser/ppstructure"
6161
"github.com/coze-dev/coze-studio/backend/infra/impl/document/rerank/rrf"
62+
vikingReranker "github.com/coze-dev/coze-studio/backend/infra/impl/document/rerank/vikingdb"
6263
"github.com/coze-dev/coze-studio/backend/infra/impl/document/searchstore/elasticsearch"
6364
"github.com/coze-dev/coze-studio/backend/infra/impl/document/searchstore/milvus"
6465
"github.com/coze-dev/coze-studio/backend/infra/impl/document/searchstore/vikingdb"
@@ -145,7 +146,7 @@ func Init(ctx context.Context) (*AppDependencies, error) {
145146
return nil, fmt.Errorf("init knowledge event bus producer failed, err=%w", err)
146147
}
147148

148-
deps.Reranker = rrf.NewRRFReranker(0)
149+
deps.Reranker = initReranker()
149150

150151
deps.Rewriter, err = initRewriter(ctx)
151152
if err != nil {
@@ -207,6 +208,26 @@ func initSearchStoreManagers(ctx context.Context, es es.Client) ([]searchstore.M
207208
return []searchstore.Manager{esSearchstoreManager, mgr}, nil
208209
}
209210

211+
func initReranker() rerank.Reranker {
212+
rerankerType := os.Getenv("RERANK_TYPE")
213+
switch rerankerType {
214+
case "vikingdb":
215+
return vikingReranker.NewReranker(getVikingRerankerConfig())
216+
case "rrf":
217+
return rrf.NewRRFReranker(0)
218+
default:
219+
return rrf.NewRRFReranker(0)
220+
}
221+
}
222+
func getVikingRerankerConfig() *vikingReranker.Config {
223+
return &vikingReranker.Config{
224+
AK: os.Getenv("VIKINGDB_RERANK_AK"),
225+
SK: os.Getenv("VIKINGDB_RERANK_SK"),
226+
Domain: os.Getenv("VIKINGDB_RERANK_HOST"),
227+
Region: os.Getenv("VIKINGDB_RERANK_REGION"),
228+
Model: os.Getenv("VIKINGDB_RERANK_MODEL"),
229+
}
230+
}
210231
func initRewriter(ctx context.Context) (messages2query.MessagesToQuery, error) {
211232
rewriterChatModel, _, err := getBuiltinChatModel(ctx, "M2Q_")
212233
if err != nil {

backend/domain/knowledge/service/retrieve.go

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -390,6 +390,10 @@ func (k *knowledgeSVC) nl2SqlExec(ctx context.Context, doc *model.KnowledgeDocum
390390
}
391391
replaceMap[doc.Name].ColumnMap[doc.TableInfo.Columns[i].Name] = convert.ColumnIDToRDBField(doc.TableInfo.Columns[i].ID)
392392
}
393+
virtualColumnMap := map[string]*entity.TableColumn{}
394+
for i := range doc.TableInfo.Columns {
395+
virtualColumnMap[convert.ColumnIDToRDBField(doc.TableInfo.Columns[i].ID)] = doc.TableInfo.Columns[i]
396+
}
393397
parsedSQL, err := sqlparser.NewSQLParser().ParseAndModifySQL(sql, replaceMap)
394398
if err != nil {
395399
logs.CtxErrorf(ctx, "parse sql failed: %v", err)
@@ -423,6 +427,32 @@ func (k *knowledgeSVC) nl2SqlExec(ctx context.Context, doc *model.KnowledgeDocum
423427
prefix := "sql:" + sql + ";result:"
424428
d.Content = prefix + string(byteData)
425429
} else {
430+
transferMap := map[string]string{}
431+
for cName, val := range resp.ResultSet.Rows[i] {
432+
column, found := virtualColumnMap[cName]
433+
if !found {
434+
logs.CtxInfof(ctx, "column not found, name: %s", cName)
435+
continue
436+
}
437+
columnData, err := convert.ParseAnyData(column, val)
438+
if err != nil {
439+
logs.CtxErrorf(ctx, "parse any data failed: %v", err)
440+
return nil, errorx.New(errno.ErrKnowledgeColumnParseFailCode, errorx.KV("msg", err.Error()))
441+
}
442+
if columnData.Type == document.TableColumnTypeString {
443+
columnData.ValString = ptr.Of(k.formatSliceContent(ctx, columnData.GetStringValue()))
444+
}
445+
if columnData.Type == document.TableColumnTypeImage {
446+
columnData.ValImage = ptr.Of(k.formatSliceContent(ctx, columnData.GetStringValue()))
447+
}
448+
transferMap[column.Name] = columnData.GetNullableStringValue()
449+
}
450+
byteData, err := sonic.Marshal(transferMap)
451+
if err != nil {
452+
logs.CtxErrorf(ctx, "marshal sql resp failed: %v", err)
453+
return nil, err
454+
}
455+
d.Content = string(byteData)
426456
d.ID = strconv.FormatInt(id, 10)
427457
}
428458
d.WithScore(1)

backend/infra/impl/document/rerank/vikingdb/vikingdb.go

Lines changed: 34 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -33,16 +33,23 @@ import (
3333
)
3434

3535
type Config struct {
36-
AK string
37-
SK string
38-
36+
AK string
37+
SK string
38+
Domain string
39+
Model string
3940
Region string // default cn-north-1
4041
}
4142

4243
func NewReranker(config *Config) rerank.Reranker {
4344
if config.Region == "" {
4445
config.Region = "cn-north-1"
4546
}
47+
if config.Domain == "" {
48+
config.Domain = domain
49+
}
50+
if config.Model == "" {
51+
config.Model = defaultModel
52+
}
4653
return &reranker{config: config}
4754
}
4855

@@ -78,12 +85,32 @@ type rerankResp struct {
7885
func (r *reranker) Rerank(ctx context.Context, req *rerank.Request) (*rerank.Response, error) {
7986
rReq := &rerankReq{
8087
Datas: make([]rerankData, 0, len(req.Data)),
81-
RerankModel: defaultModel,
88+
RerankModel: r.config.Model,
8289
}
83-
90+
sorted := make([]*rerank.Data, 0)
8491
var flat []*rerank.Data
92+
visited := map[string]bool{}
8593
for _, channel := range req.Data {
86-
flat = append(flat, channel...)
94+
if len(channel) == 0 {
95+
continue
96+
}
97+
for _, item := range channel {
98+
if item == nil || item.Document == nil {
99+
continue
100+
}
101+
if item.Document.ID == "" {
102+
sorted = append(sorted, &rerank.Data{
103+
Document: item.Document,
104+
Score: 1,
105+
})
106+
continue
107+
}
108+
if visited[item.Document.ID] {
109+
continue
110+
}
111+
visited[item.Document.ID] = true
112+
flat = append(flat, item)
113+
}
87114
}
88115

89116
for _, item := range flat {
@@ -117,7 +144,6 @@ func (r *reranker) Rerank(ctx context.Context, req *rerank.Request) (*rerank.Res
117144
return nil, fmt.Errorf("[Rerank] failed, code=%d, msg=%v", rResp.Code, rResp.Message)
118145
}
119146

120-
sorted := make([]*rerank.Data, 0, len(rResp.Data.Scores))
121147
for i, score := range rResp.Data.Scores {
122148
sorted = append(sorted, &rerank.Data{
123149
Document: flat[i].Document,
@@ -143,7 +169,7 @@ func (r *reranker) Rerank(ctx context.Context, req *rerank.Request) (*rerank.Res
143169
func (r *reranker) prepareRequest(body []byte) *http.Request {
144170
u := url.URL{
145171
Scheme: "https",
146-
Host: domain,
172+
Host: r.config.Domain,
147173
Path: "/api/knowledge/service/rerank",
148174
}
149175
req, _ := http.NewRequest(http.MethodPost, u.String(), bytes.NewReader(body))

docker/.env.debug.example

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,16 @@ export GEMINI_EMBEDDING_LOCATION="" # (string, optional) Gemini
139139
export HTTP_EMBEDDING_ADDR="" # (string, required) http embedding address
140140
export HTTP_EMBEDDING_DIMS=1024 # (string, required) http embedding dimensions
141141

142+
# Settings for Rerank
143+
# If you want to use the rerank-related functions in the knowledge base feature, you need to set up the rerank configuration.
144+
export RERANK_TYPE="" # currently supported: `vikingdb`, `rrf`; default: rrf
145+
# vikingdb rerank
146+
export VIKINGDB_RERANK_HOST="" # optional,default:api-knowledgebase.mlp.cn-beijing.volces.com
147+
export VIKINGDB_RERANK_REGION="" # optional,default:cn-north-1
148+
export VIKINGDB_RERANK_AK="" # required
149+
export VIKINGDB_RERANK_SK="" # required
150+
export VIKINGDB_RERANK_MODEL="" # optional, default: base-multilingual-rerank; m3-v2-rerank is also supported
151+
142152
# Settings for OCR
143153
# If you want to use the OCR-related functions in the knowledge base feature, you need to set up the OCR configuration.
144154
# Currently, Coze Studio has built-in Volcano OCR.

docker/.env.example

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,16 @@ export GEMINI_EMBEDDING_LOCATION="" # (string, optional) Gemini
137137
export HTTP_EMBEDDING_ADDR="" # (string, required) http embedding address
138138
export HTTP_EMBEDDING_DIMS=1024 # (string, required) http embedding dimensions
139139

140+
# Settings for Rerank
141+
# If you want to use the rerank-related functions in the knowledge base feature, you need to set up the rerank configuration.
142+
export RERANK_TYPE="" # currently supported: `vikingdb`, `rrf`; default: rrf
143+
# vikingdb rerank
144+
export VIKINGDB_RERANK_HOST="" # optional,default:api-knowledgebase.mlp.cn-beijing.volces.com
145+
export VIKINGDB_RERANK_REGION="" # optional,default:cn-north-1
146+
export VIKINGDB_RERANK_AK="" # required
147+
export VIKINGDB_RERANK_SK="" # required
148+
export VIKINGDB_RERANK_MODEL="" # optional, default: base-multilingual-rerank; m3-v2-rerank is also supported
149+
140150
# Settings for OCR
141151
# If you want to use the OCR-related functions in the knowledge base feature, you need to set up the OCR configuration.
142152
# Currently, Coze Studio has built-in Volcano OCR.

0 commit comments

Comments
 (0)