Skip to content

Commit 25d4e3d

Browse files
committed
chore: refactor embedding model parameter handling
1 parent 18543a0 commit 25d4e3d

File tree

1 file changed

+15
-4
lines changed

1 file changed

+15
-4
lines changed

apps/knowledge/task/embedding.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
# coding=utf-8
22

3-
import logging
43
import traceback
54
from typing import List
65

@@ -14,12 +13,11 @@
1413
from common.utils.logger import maxkb_logger
1514
from knowledge.models import Document, TaskType, State
1615
from knowledge.serializers.common import drop_knowledge_index
17-
from models_provider.tools import get_model
1816
from models_provider.models import Model
17+
from models_provider.tools import get_model
1918
from ops import celery_app
2019

2120

22-
2321
def get_embedding_model(model_id, exception_handler=lambda e: maxkb_logger.error(
2422
_('Failed to obtain vector model: {error} {traceback}').format(
2523
error=str(e),
@@ -28,7 +26,20 @@ def get_embedding_model(model_id, exception_handler=lambda e: maxkb_logger.error
2826
try:
2927
model = QuerySet(Model).filter(id=model_id).first()
3028

31-
s = {p.get('field'): p.get('default_value') for p in model.model_params_form if p.get('default_value') is not None}
29+
def convert_to_int(value):
30+
if isinstance(value, str):
31+
try:
32+
return int(value)
33+
except ValueError:
34+
return value
35+
return value
36+
37+
s = {
38+
p.get('field'): convert_to_int(p.get('default_value'))
39+
for p in model.model_params_form
40+
if p.get('default_value') is not None
41+
}
42+
3243
embedding_model = ModelManage.get_model(model_id, lambda _id: get_model(model, **{**s}))
3344
except Exception as e:
3445
exception_handler(e)

0 commit comments

Comments
 (0)