Skip to content

Commit ed19db0

Browse files
committed
feat: add function to retrieve default parameters for embedding models
--bug=1063177 --user=刘瑞斌 【知识库】-知识库使用的模型更换维度参数值并重新向量化后,命中测试、检索报错 https://www.tapd.cn/62980211/s/1792117
1 parent 2de6bd2 commit ed19db0

File tree

2 files changed

+36
-21
lines changed

2 files changed

+36
-21
lines changed

apps/knowledge/serializers/common.py

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -112,24 +112,51 @@ def to_problem_model_list(self):
112112
], problem_paragraph_mapping_list
113113
return result
114114

115+
def get_embedding_model_default_params(model):
116+
def convert_to_int(value):
117+
if isinstance(value, str):
118+
try:
119+
return int(value)
120+
except ValueError:
121+
return value
122+
return value
123+
124+
return {
125+
p.get('field'): convert_to_int(p.get('default_value'))
126+
for p in model.model_params_form
127+
if p.get('default_value') is not None
128+
}
129+
115130

116131
def get_embedding_model_by_knowledge_id_list(knowledge_id_list: List):
117132
knowledge_list = QuerySet(Knowledge).filter(id__in=knowledge_id_list)
118133
if len(set([knowledge.embedding_model_id for knowledge in knowledge_list])) > 1:
119134
raise Exception(_('The knowledge base is inconsistent with the vector model'))
120135
if len(knowledge_list) == 0:
121136
raise Exception(_('Knowledge base setting error, please reset the knowledge base'))
122-
return ModelManage.get_model(str(knowledge_list[0].embedding_model_id),
123-
lambda _id: get_model(knowledge_list[0].embedding_model))
137+
138+
default_params = get_embedding_model_default_params(knowledge_list[0].embedding_model)
139+
140+
return ModelManage.get_model(
141+
str(knowledge_list[0].embedding_model_id),
142+
lambda _id: get_model(knowledge_list[0].embedding_model, **{**default_params})
143+
)
124144

125145

126146
def get_embedding_model_by_knowledge_id(knowledge_id: str):
127147
knowledge = QuerySet(Knowledge).select_related('embedding_model').filter(id=knowledge_id).first()
128-
return ModelManage.get_model(str(knowledge.embedding_model_id), lambda _id: get_model(knowledge.embedding_model))
148+
149+
default_params = get_embedding_model_default_params(knowledge.embedding_model)
150+
151+
return ModelManage.get_model(str(knowledge.embedding_model_id),
152+
lambda _id: get_model(knowledge.embedding_model, **{**default_params}))
129153

130154

131155
def get_embedding_model_by_knowledge(knowledge):
132-
return ModelManage.get_model(str(knowledge.embedding_model_id), lambda _id: get_model(knowledge.embedding_model))
156+
default_params = get_embedding_model_default_params(knowledge.embedding_model)
157+
158+
return ModelManage.get_model(str(knowledge.embedding_model_id),
159+
lambda _id: get_model(knowledge.embedding_model, **{**default_params}))
133160

134161

135162
def get_embedding_model_id_by_knowledge_id(knowledge_id):
@@ -241,7 +268,7 @@ def create_knowledge_index(knowledge_id=None, document_id=None):
241268
result = sql_execute(sql, [])
242269
if len(result) == 0:
243270
return
244-
dims = result[0]['dims']
271+
dims = result[0]['dims']
245272
sql = f"""CREATE INDEX "embedding_hnsw_idx_{k_id}" ON embedding USING hnsw ((embedding::vector({dims})) vector_cosine_ops) WHERE knowledge_id = '{k_id}'"""
246273
update_execute(sql, [])
247274
maxkb_logger.info(f'Created index for knowledge ID: {k_id}')

apps/knowledge/task/embedding.py

Lines changed: 4 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
UpdateEmbeddingDocumentIdArgs
1313
from common.utils.logger import maxkb_logger
1414
from knowledge.models import Document, TaskType, State
15-
from knowledge.serializers.common import drop_knowledge_index
15+
from knowledge.serializers.common import drop_knowledge_index, get_embedding_model_default_params
1616
from models_provider.models import Model
1717
from models_provider.tools import get_model
1818
from ops import celery_app
@@ -26,21 +26,9 @@ def get_embedding_model(model_id, exception_handler=lambda e: maxkb_logger.error
2626
try:
2727
model = QuerySet(Model).filter(id=model_id).first()
2828

29-
def convert_to_int(value):
30-
if isinstance(value, str):
31-
try:
32-
return int(value)
33-
except ValueError:
34-
return value
35-
return value
36-
37-
s = {
38-
p.get('field'): convert_to_int(p.get('default_value'))
39-
for p in model.model_params_form
40-
if p.get('default_value') is not None
41-
}
42-
43-
embedding_model = ModelManage.get_model(model_id, lambda _id: get_model(model, **{**s}))
29+
default_params = get_embedding_model_default_params(model)
30+
31+
embedding_model = ModelManage.get_model(model_id, lambda _id: get_model(model, **{**default_params}))
4432
except Exception as e:
4533
exception_handler(e)
4634
raise e

0 commit comments

Comments
 (0)