Skip to content

Commit d0cd6d6

Browse files
committed
refactor: replace get_embedding_model_default_params with get_model_default_params
1 parent 1b08643 commit d0cd6d6

File tree

4 files changed

+25
-26
lines changed

4 files changed

+25
-26
lines changed

apps/application/chat_pipeline/step/search_dataset_step/impl/base_search_dataset_step.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,9 @@
2222
from common.utils.common import get_file_content
2323
from knowledge.models import Paragraph, Knowledge
2424
from knowledge.models import SearchMode
25-
from knowledge.serializers.common import get_embedding_model_default_params
2625
from maxkb.conf import PROJECT_DIR
2726
from models_provider.models import Model
28-
from models_provider.tools import get_model, get_model_by_id
27+
from models_provider.tools import get_model, get_model_by_id, get_model_default_params
2928

3029

3130
def reset_meta(meta):
@@ -65,7 +64,7 @@ def execute(self, problem_text: str, knowledge_id_list: list[str], exclude_docum
6564
if model.model_type != "EMBEDDING":
6665
raise Exception(_("Model does not exist"))
6766
self.context['model_name'] = model.name
68-
default_params = get_embedding_model_default_params(model)
67+
default_params = get_model_default_params(model)
6968
embedding_model = ModelManage.get_model(model_id, lambda _id: get_model(model, **{**default_params}))
7069
embedding_value = embedding_model.embed_query(exec_problem_text)
7170
vector = VectorStore.get_embedding_vector()

apps/knowledge/serializers/common.py

Lines changed: 4 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from knowledge.models import Document
2727
from knowledge.models import Paragraph, Problem, ProblemParagraphMapping, Knowledge, File
2828
from maxkb.conf import PROJECT_DIR
29-
from models_provider.tools import get_model
29+
from models_provider.tools import get_model, get_model_default_params
3030

3131

3232
class MetaSerializer(serializers.Serializer):
@@ -112,21 +112,6 @@ def to_problem_model_list(self):
112112
], problem_paragraph_mapping_list
113113
return result
114114

115-
def get_embedding_model_default_params(model):
116-
def convert_to_int(value):
117-
if isinstance(value, str):
118-
try:
119-
return int(value)
120-
except ValueError:
121-
return value
122-
return value
123-
124-
return {
125-
p.get('field'): convert_to_int(p.get('default_value'))
126-
for p in model.model_params_form
127-
if p.get('default_value') is not None
128-
}
129-
130115

131116
def get_embedding_model_by_knowledge_id_list(knowledge_id_list: List):
132117
knowledge_list = QuerySet(Knowledge).filter(id__in=knowledge_id_list)
@@ -135,7 +120,7 @@ def get_embedding_model_by_knowledge_id_list(knowledge_id_list: List):
135120
if len(knowledge_list) == 0:
136121
raise Exception(_('Knowledge base setting error, please reset the knowledge base'))
137122

138-
default_params = get_embedding_model_default_params(knowledge_list[0].embedding_model)
123+
default_params = get_model_default_params(knowledge_list[0].embedding_model)
139124

140125
return ModelManage.get_model(
141126
str(knowledge_list[0].embedding_model_id),
@@ -146,14 +131,14 @@ def get_embedding_model_by_knowledge_id_list(knowledge_id_list: List):
146131
def get_embedding_model_by_knowledge_id(knowledge_id: str):
147132
knowledge = QuerySet(Knowledge).select_related('embedding_model').filter(id=knowledge_id).first()
148133

149-
default_params = get_embedding_model_default_params(knowledge.embedding_model)
134+
default_params = get_model_default_params(knowledge.embedding_model)
150135

151136
return ModelManage.get_model(str(knowledge.embedding_model_id),
152137
lambda _id: get_model(knowledge.embedding_model, **{**default_params}))
153138

154139

155140
def get_embedding_model_by_knowledge(knowledge):
156-
default_params = get_embedding_model_default_params(knowledge.embedding_model)
141+
default_params = get_model_default_params(knowledge.embedding_model)
157142

158143
return ModelManage.get_model(str(knowledge.embedding_model_id),
159144
lambda _id: get_model(knowledge.embedding_model, **{**default_params}))

apps/knowledge/task/embedding.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@
1212
UpdateEmbeddingDocumentIdArgs
1313
from common.utils.logger import maxkb_logger
1414
from knowledge.models import Document, TaskType, State
15-
from knowledge.serializers.common import drop_knowledge_index, get_embedding_model_default_params
15+
from knowledge.serializers.common import drop_knowledge_index
1616
from models_provider.models import Model
17-
from models_provider.tools import get_model
17+
from models_provider.tools import get_model, get_model_default_params
1818
from ops import celery_app
1919

2020

@@ -26,7 +26,7 @@ def get_embedding_model(model_id, exception_handler=lambda e: maxkb_logger.error
2626
try:
2727
model = QuerySet(Model).filter(id=model_id).first()
2828

29-
default_params = get_embedding_model_default_params(model)
29+
default_params = get_model_default_params(model)
3030

3131
embedding_model = ModelManage.get_model(model_id, lambda _id: get_model(model, **{**default_params}))
3232
except Exception as e:

apps/models_provider/tools.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,21 @@ def get_model_by_id(_id, workspace_id):
115115
raise Exception(_("Model does not exist"))
116116
return model
117117

118+
def get_model_default_params(model):
119+
def convert_to_int(value):
120+
if isinstance(value, str):
121+
try:
122+
return int(value)
123+
except ValueError:
124+
return value
125+
return value
126+
127+
return {
128+
p.get('field'): convert_to_int(p.get('default_value'))
129+
for p in model.model_params_form
130+
if p.get('default_value') is not None
131+
}
132+
118133

119134
def get_model_instance_by_model_workspace_id(model_id, workspace_id, **kwargs):
120135
"""
@@ -124,5 +139,5 @@ def get_model_instance_by_model_workspace_id(model_id, workspace_id, **kwargs):
124139
@return: 模型实例
125140
"""
126141
model = get_model_by_id(model_id, workspace_id)
127-
s = {p.get('field'): p.get('default_value') for p in model.model_params_form if p.get('default_value') is not None}
142+
s = get_model_default_params(model)
128143
return ModelManage.get_model(model_id, lambda _id: get_model(model, **{**s, **kwargs}))

0 commit comments

Comments
 (0)