Skip to content

Commit b6f5a83

Browse files
committed
feat: add HitTest API for knowledge base query testing and update SyncWeb API to use knowledge_id
1 parent 7dcd1a7 commit b6f5a83

File tree

4 files changed

+100
-10
lines changed

4 files changed

+100
-10
lines changed

apps/knowledge/api/knowledge.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from common.result import ResultSerializer, DefaultResultSerializer
66
from knowledge.serializers.common import GenerateRelatedSerializer
77
from knowledge.serializers.knowledge import KnowledgeBaseCreateRequest, KnowledgeModelSerializer, KnowledgeEditRequest, \
8-
KnowledgeWebCreateRequest
8+
KnowledgeWebCreateRequest, HitTestSerializer
99

1010

1111
class KnowledgeCreateResponse(ResultSerializer):
@@ -238,3 +238,9 @@ class GenerateRelatedAPI(SyncWebAPI):
238238
@staticmethod
239239
def get_request():
240240
return GenerateRelatedSerializer
241+
242+
243+
class HitTestAPI(SyncWebAPI):
    """API schema description for the knowledge-base hit-test endpoint.

    Only the request schema is overridden here; everything else is
    inherited unchanged from ``SyncWebAPI``.
    """

    @staticmethod
    def get_request():
        """Return the serializer class that describes a hit-test request."""
        return HitTestSerializer

apps/knowledge/serializers/knowledge.py

Lines changed: 63 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from django.utils.translation import gettext_lazy as _
1515
from rest_framework import serializers
1616

17+
from common.config.embedding_config import VectorStore
1718
from common.db.search import native_search, get_dynamics_model, native_page_search
1819
from common.db.sql_execute import select_list
1920
from common.event import ListenerManagement
@@ -22,9 +23,9 @@
2223
from common.utils.fork import Fork, ChildLink
2324
from common.utils.split_model import get_split_model
2425
from knowledge.models import Knowledge, KnowledgeScope, KnowledgeType, Document, Paragraph, Problem, \
25-
ProblemParagraphMapping, ApplicationKnowledgeMapping, TaskType, State
26+
ProblemParagraphMapping, ApplicationKnowledgeMapping, TaskType, State, SearchMode
2627
from knowledge.serializers.common import ProblemParagraphManage, get_embedding_model_id_by_knowledge_id, MetaSerializer, \
27-
GenerateRelatedSerializer
28+
GenerateRelatedSerializer, get_embedding_model_by_knowledge_id, list_paragraph
2829
from knowledge.serializers.document import DocumentSerializers
2930
from knowledge.task.embedding import embedding_by_knowledge, delete_embedding_by_knowledge
3031
from knowledge.task.generate import generate_related_by_knowledge_id
@@ -79,6 +80,14 @@ def is_valid(self, *, knowledge: Knowledge = None):
7980
valid_class = knowledge_meta_valid_map.get(knowledge.type)
8081
valid_class(data=self.data.get('meta')).is_valid(raise_exception=True)
8182

83+
class HitTestSerializer(serializers.Serializer):
    """Fields describing a knowledge-base hit-test query."""

    # Text to search the knowledge base with.
    query_text = serializers.CharField(required=True, label=_('query text'))
    # Maximum number of hits to return; must be within 1..10000.
    top_number = serializers.IntegerField(required=True, max_value=10000, min_value=1, label=_("top number"))
    # Similarity threshold; must be within 0..2.
    similarity = serializers.FloatField(required=True, max_value=2, min_value=0, label=_('similarity'))
    # Search strategy. The alternation is grouped so that ``^`` and ``$``
    # apply to every alternative: without the group the pattern reads as
    # "^embedding" OR "keywords" OR "blend$", and since the validator
    # searches the value, strings such as "xkeywordsx" were accepted.
    search_mode = serializers.CharField(required=True, label=_('search mode'), validators=[
        validators.RegexValidator(regex=re.compile("^(?:embedding|keywords|blend)$"),
                                  message=_('The type only supports embedding|keywords|blend'), code=500)
    ])
8291

8392
class KnowledgeSerializer(serializers.Serializer):
8493
class Query(serializers.Serializer):
@@ -152,7 +161,7 @@ def generate_related(self, instance: Dict, with_valid=True):
152161
if with_valid:
153162
self.is_valid(raise_exception=True)
154163
GenerateRelatedSerializer(data=instance).is_valid(raise_exception=True)
155-
knowledge_id = self.data.get('id')
164+
knowledge_id = self.data.get('knowledge_id')
156165
model_id = instance.get("model_id")
157166
prompt = instance.get("prompt")
158167
state_list = instance.get('state_list')
@@ -382,15 +391,16 @@ def save_web(self, instance: Dict, with_valid=True):
382391
return {**KnowledgeModelSerializer(knowledge).data, 'document_list': []}
383392

384393
class SyncWeb(serializers.Serializer):
385-
id = serializers.CharField(required=True, label=_('knowledge id'))
394+
workspace_id = serializers.CharField(required=True, label=_('workspace id'))
395+
knowledge_id = serializers.CharField(required=True, label=_('knowledge id'))
386396
user_id = serializers.UUIDField(required=False, label=_('user id'))
387397
sync_type = serializers.CharField(required=True, label=_('sync type'), validators=[
388398
validators.RegexValidator(regex=re.compile("^replace|complete$"),
389399
message=_('The synchronization type only supports:replace|complete'), code=500)])
390400

391401
def is_valid(self, *, raise_exception=False):
392402
super().is_valid(raise_exception=True)
393-
first = QuerySet(Knowledge).filter(id=self.data.get("id")).first()
403+
first = QuerySet(Knowledge).filter(id=self.data.get("knowledge_id")).first()
394404
if first is None:
395405
raise AppApiException(300, _('id does not exist'))
396406
if first.type != KnowledgeType.WEB:
@@ -400,7 +410,7 @@ def sync(self, with_valid=True):
400410
if with_valid:
401411
self.is_valid(raise_exception=True)
402412
sync_type = self.data.get('sync_type')
403-
knowledge_id = self.data.get('id')
413+
knowledge_id = self.data.get('knowledge_id')
404414
knowledge = QuerySet(Knowledge).get(id=knowledge_id)
405415
self.__getattribute__(sync_type + '_sync')(knowledge)
406416
return True
@@ -454,6 +464,52 @@ def complete_sync(self, knowledge):
454464
# 删除段落
455465
QuerySet(Paragraph).filter(knowledge=knowledge).delete()
456466
# 删除向量
457-
delete_embedding_by_knowledge(self.data.get('id'))
467+
delete_embedding_by_knowledge(self.data.get('knowledge_id'))
458468
# 同步
459469
self.replace_sync(knowledge)
470+
471+
class HitTest(serializers.Serializer):
    """Run a hit test (retrieval query) against a single knowledge base."""

    workspace_id = serializers.CharField(required=True, label=_('workspace id'))
    knowledge_id = serializers.UUIDField(required=True, label=_("id"))
    user_id = serializers.UUIDField(required=False, label=_('user id'))
    query_text = serializers.CharField(required=True, label=_('query text'))
    # Maximum number of hits to return; must be within 1..10000.
    top_number = serializers.IntegerField(required=True, max_value=10000, min_value=1, label=_("top number"))
    # Similarity threshold; must be within 0..2.
    similarity = serializers.FloatField(required=True, max_value=2, min_value=0, label=_('similarity'))
    # The alternation is grouped so the anchors cover every alternative.
    # Ungrouped, "^embedding|keywords|blend$" accepts e.g. "xkeywordsx",
    # which would then fail later when converted with SearchMode(...).
    search_mode = serializers.CharField(required=True, label=_('search mode'), validators=[
        validators.RegexValidator(regex=re.compile("^(?:embedding|keywords|blend)$"),
                                  message=_('The type only supports embedding|keywords|blend'), code=500)
    ])

    def is_valid(self, *, raise_exception=True):
        """Validate the fields and check that the knowledge base exists.

        Field validation always raises on failure, regardless of
        ``raise_exception`` (behaviour kept from the original code).

        Returns:
            True when validation succeeds.

        Raises:
            AppApiException: code 300 when no Knowledge row has the given id.
        """
        super().is_valid(raise_exception=True)
        if not QuerySet(Knowledge).filter(id=self.data.get("knowledge_id")).exists():
            raise AppApiException(300, _('id does not exist'))
        # Return a boolean like the base class does, instead of None.
        return True

    def hit_test(self):
        """Query the vector store and return the matching paragraphs.

        Returns:
            A list with one dict per paragraph returned by ``list_paragraph``,
            each extended with the ``similarity`` and ``comprehensive_score``
            of its hit.
        """
        self.is_valid()
        knowledge_id = self.data.get('knowledge_id')
        vector = VectorStore.get_embedding_vector()
        # Inactive documents are excluded from retrieval; only ids are
        # needed, so avoid loading full Document rows.
        exclude_document_id_list = [
            str(document_id)
            for document_id in QuerySet(Document)
            .filter(knowledge_id=knowledge_id, is_active=False)
            .values_list('id', flat=True)
        ]
        model = get_embedding_model_by_knowledge_id(knowledge_id)
        # Vector store retrieval
        hit_list = vector.hit_test(
            self.data.get('query_text'),
            [knowledge_id],
            exclude_document_id_list,
            self.data.get('top_number'),
            self.data.get('similarity'),
            SearchMode(self.data.get('search_mode')),
            model
        )
        # Index hits by paragraph id; a later hit for the same paragraph
        # replaces an earlier one, as the original reduce-based merge did.
        hit_dict = {hit.get('paragraph_id'): hit for hit in hit_list}
        p_list = list_paragraph([h.get('paragraph_id') for h in hit_list])
        paragraphs = []
        for p in p_list:
            # NOTE(review): assumes p['id'] has the same type as the hits'
            # 'paragraph_id' (str vs UUID) - confirm against list_paragraph.
            hit = hit_dict.get(p.get('id'))
            paragraphs.append({
                **p,
                'similarity': hit.get('similarity'),
                'comprehensive_score': hit.get('comprehensive_score')
            })
        return paragraphs

apps/knowledge/urls.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>', views.KnowledgeView.Operate.as_view()),
1111
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/sync', views.KnowledgeView.SyncWeb.as_view()),
1212
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/generate_related', views.KnowledgeView.GenerateRelated.as_view()),
13+
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/hit_test', views.KnowledgeView.HitTest.as_view()),
1314
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document', views.DocumentView.as_view()),
1415
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/split', views.DocumentView.Split.as_view()),
1516
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/split_pattern', views.DocumentView.SplitPattern.as_view()),

apps/knowledge/views/knowledge.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from common.constants.permission_constants import PermissionConstants
99
from common.result import result
1010
from knowledge.api.knowledge import KnowledgeBaseCreateAPI, KnowledgeWebCreateAPI, KnowledgeTreeReadAPI, \
11-
KnowledgeEditAPI, KnowledgeReadAPI, KnowledgePageAPI, SyncWebAPI, GenerateRelatedAPI
11+
KnowledgeEditAPI, KnowledgeReadAPI, KnowledgePageAPI, SyncWebAPI, GenerateRelatedAPI, HitTestAPI
1212
from knowledge.serializers.knowledge import KnowledgeSerializer
1313

1414

@@ -128,11 +128,38 @@ def put(self, request: Request, workspace_id: str, knowledge_id: str):
128128
data={
129129
'workspace_id': workspace_id,
130130
'sync_type': request.query_params.get('sync_type'),
131-
'id': knowledge_id,
131+
'knowledge_id': knowledge_id,
132132
'user_id': str(request.user.id)
133133
}
134134
).sync())
135135

136+
class HitTest(APIView):
    """View that runs a hit test against one knowledge base."""

    authentication_classes = [TokenAuth]

    @extend_schema(
        methods=['PUT'],
        summary=_('Hit test list'),
        description=_('Hit test list'),
        operation_id=_('Hit test list'),
        parameters=HitTestAPI.get_parameters(),
        request=HitTestAPI.get_request(),
        responses=HitTestAPI.get_response(),
        tags=[_('Knowledge Base')]
    )
    @has_permissions(PermissionConstants.KNOWLEDGE_EDIT.get_workspace_permission())
    def put(self, request: Request, workspace_id: str, knowledge_id: str):
        """Run the hit test and wrap the hits in a success result.

        The workspace and knowledge ids come from the URL, the user id
        from the authenticated request, and the search settings from the
        query string.
        """
        query = request.query_params
        payload = {
            'workspace_id': workspace_id,
            'knowledge_id': knowledge_id,
            'user_id': request.user.id,
            "query_text": query.get("query_text"),
            "top_number": query.get("top_number"),
            'similarity': query.get('similarity'),
            'search_mode': query.get('search_mode')
        }
        hits = KnowledgeSerializer.HitTest(data=payload).hit_test()
        return result.success(hits)
162+
136163
class GenerateRelated(APIView):
137164
authentication_classes = [TokenAuth]
138165

0 commit comments

Comments
 (0)