1414from django .utils .translation import gettext_lazy as _
1515from rest_framework import serializers
1616
17+ from common .config .embedding_config import VectorStore
1718from common .db .search import native_search , get_dynamics_model , native_page_search
1819from common .db .sql_execute import select_list
1920from common .event import ListenerManagement
2223from common .utils .fork import Fork , ChildLink
2324from common .utils .split_model import get_split_model
2425from knowledge .models import Knowledge , KnowledgeScope , KnowledgeType , Document , Paragraph , Problem , \
25- ProblemParagraphMapping , ApplicationKnowledgeMapping , TaskType , State
26+ ProblemParagraphMapping , ApplicationKnowledgeMapping , TaskType , State , SearchMode
2627from knowledge .serializers .common import ProblemParagraphManage , get_embedding_model_id_by_knowledge_id , MetaSerializer , \
27- GenerateRelatedSerializer
28+ GenerateRelatedSerializer , get_embedding_model_by_knowledge_id , list_paragraph
2829from knowledge .serializers .document import DocumentSerializers
2930from knowledge .task .embedding import embedding_by_knowledge , delete_embedding_by_knowledge
3031from knowledge .task .generate import generate_related_by_knowledge_id
@@ -79,6 +80,14 @@ def is_valid(self, *, knowledge: Knowledge = None):
7980 valid_class = knowledge_meta_valid_map .get (knowledge .type )
8081 valid_class (data = self .data .get ('meta' )).is_valid (raise_exception = True )
8182
83+ class HitTestSerializer (serializers .Serializer ):
84+ query_text = serializers .CharField (required = True , label = _ ('query text' ))
85+ top_number = serializers .IntegerField (required = True , max_value = 10000 , min_value = 1 , label = _ ("top number" ))
86+ similarity = serializers .FloatField (required = True , max_value = 2 , min_value = 0 , label = _ ('similarity' ))
87+ search_mode = serializers .CharField (required = True , label = _ ('search mode' ), validators = [
88+ validators .RegexValidator (regex = re .compile ("^embedding|keywords|blend$" ),
89+ message = _ ('The type only supports embedding|keywords|blend' ), code = 500 )
90+ ])
8291
8392class KnowledgeSerializer (serializers .Serializer ):
8493 class Query (serializers .Serializer ):
@@ -152,7 +161,7 @@ def generate_related(self, instance: Dict, with_valid=True):
152161 if with_valid :
153162 self .is_valid (raise_exception = True )
154163 GenerateRelatedSerializer (data = instance ).is_valid (raise_exception = True )
155- knowledge_id = self .data .get ('id ' )
164+ knowledge_id = self .data .get ('knowledge_id ' )
156165 model_id = instance .get ("model_id" )
157166 prompt = instance .get ("prompt" )
158167 state_list = instance .get ('state_list' )
@@ -382,15 +391,16 @@ def save_web(self, instance: Dict, with_valid=True):
382391 return {** KnowledgeModelSerializer (knowledge ).data , 'document_list' : []}
383392
384393 class SyncWeb (serializers .Serializer ):
385- id = serializers .CharField (required = True , label = _ ('knowledge id' ))
394+ workspace_id = serializers .CharField (required = True , label = _ ('workspace id' ))
395+ knowledge_id = serializers .CharField (required = True , label = _ ('knowledge id' ))
386396 user_id = serializers .UUIDField (required = False , label = _ ('user id' ))
387397 sync_type = serializers .CharField (required = True , label = _ ('sync type' ), validators = [
388398 validators .RegexValidator (regex = re .compile ("^replace|complete$" ),
389399 message = _ ('The synchronization type only supports:replace|complete' ), code = 500 )])
390400
391401 def is_valid (self , * , raise_exception = False ):
392402 super ().is_valid (raise_exception = True )
393- first = QuerySet (Knowledge ).filter (id = self .data .get ("id " )).first ()
403+ first = QuerySet (Knowledge ).filter (id = self .data .get ("knowledge_id " )).first ()
394404 if first is None :
395405 raise AppApiException (300 , _ ('id does not exist' ))
396406 if first .type != KnowledgeType .WEB :
@@ -400,7 +410,7 @@ def sync(self, with_valid=True):
400410 if with_valid :
401411 self .is_valid (raise_exception = True )
402412 sync_type = self .data .get ('sync_type' )
403- knowledge_id = self .data .get ('id ' )
413+ knowledge_id = self .data .get ('knowledge_id ' )
404414 knowledge = QuerySet (Knowledge ).get (id = knowledge_id )
405415 self .__getattribute__ (sync_type + '_sync' )(knowledge )
406416 return True
@@ -454,6 +464,52 @@ def complete_sync(self, knowledge):
454464 # 删除段落
455465 QuerySet (Paragraph ).filter (knowledge = knowledge ).delete ()
456466 # 删除向量
457- delete_embedding_by_knowledge (self .data .get ('id ' ))
467+ delete_embedding_by_knowledge (self .data .get ('knowledge_id ' ))
458468 # 同步
459469 self .replace_sync (knowledge )
470+
471+ class HitTest (serializers .Serializer ):
472+ workspace_id = serializers .CharField (required = True , label = _ ('workspace id' ))
473+ knowledge_id = serializers .UUIDField (required = True , label = _ ("id" ))
474+ user_id = serializers .UUIDField (required = False , label = _ ('user id' ))
475+ query_text = serializers .CharField (required = True , label = _ ('query text' ))
476+ top_number = serializers .IntegerField (required = True , max_value = 10000 , min_value = 1 , label = _ ("top number" ))
477+ similarity = serializers .FloatField (required = True , max_value = 2 , min_value = 0 , label = _ ('similarity' ))
478+ search_mode = serializers .CharField (required = True , label = _ ('search mode' ), validators = [
479+ validators .RegexValidator (regex = re .compile ("^embedding|keywords|blend$" ),
480+ message = _ ('The type only supports embedding|keywords|blend' ), code = 500 )
481+ ])
482+
483+ def is_valid (self , * , raise_exception = True ):
484+ super ().is_valid (raise_exception = True )
485+ if not QuerySet (Knowledge ).filter (id = self .data .get ("knowledge_id" )).exists ():
486+ raise AppApiException (300 , _ ('id does not exist' ))
487+
488+ def hit_test (self ):
489+ self .is_valid ()
490+ vector = VectorStore .get_embedding_vector ()
491+ exclude_document_id_list = [
492+ str (
493+ document .id
494+ ) for document in QuerySet (Document ).filter (knowledge_id = self .data .get ('knowledge_id' ), is_active = False )
495+ ]
496+ model = get_embedding_model_by_knowledge_id (self .data .get ('knowledge_id' ))
497+ # 向量库检索
498+ hit_list = vector .hit_test (
499+ self .data .get ('query_text' ),
500+ [self .data .get ('knowledge_id' )],
501+ exclude_document_id_list ,
502+ self .data .get ('top_number' ),
503+ self .data .get ('similarity' ),
504+ SearchMode (self .data .get ('search_mode' )),
505+ model
506+ )
507+ hit_dict = reduce (lambda x , y : {** x , ** y }, [{hit .get ('paragraph_id' ): hit } for hit in hit_list ], {})
508+ p_list = list_paragraph ([h .get ('paragraph_id' ) for h in hit_list ])
509+ return [
510+ {
511+ ** p ,
512+ 'similarity' : hit_dict .get (p .get ('id' )).get ('similarity' ),
513+ 'comprehensive_score' : hit_dict .get (p .get ('id' )).get ('comprehensive_score' )
514+ } for p in p_list
515+ ]
0 commit comments