33from pydantic import BaseModel , Field
44
55from core .rag .datasource .retrieval_service import RetrievalService
6+ from core .rag .entities .context_entities import DocumentContext
67from core .rag .models .document import Document as RetrievalDocument
78from core .rag .retrieval .retrieval_methods import RetrievalMethod
89from core .tools .tool .dataset_retriever .dataset_retriever_base_tool import DatasetRetrieverBaseTool
910from extensions .ext_database import db
10- from models .dataset import Dataset , Document , DocumentSegment
11+ from models .dataset import Dataset
12+ from models .dataset import Document as DatasetDocument
1113from services .external_knowledge_service import ExternalDatasetService
1214
1315default_retrieval_model = {
@@ -54,7 +56,6 @@ def _run(self, query: str) -> str:
5456
5557 if not dataset :
5658 return ""
57-
5859 for hit_callback in self .hit_callbacks :
5960 hit_callback .on_query (query , dataset .id )
6061 if dataset .provider == "external" :
@@ -125,7 +126,6 @@ def _run(self, query: str) -> str:
125126 )
126127 else :
127128 documents = []
128-
129129 for hit_callback in self .hit_callbacks :
130130 hit_callback .on_tool_end (documents )
131131 document_score_list = {}
@@ -134,50 +134,46 @@ def _run(self, query: str) -> str:
134134 if item .metadata is not None and item .metadata .get ("score" ):
135135 document_score_list [item .metadata ["doc_id" ]] = item .metadata ["score" ]
136136 document_context_list = []
137- index_node_ids = [document .metadata ["doc_id" ] for document in documents ]
138- segments = DocumentSegment .query .filter (
139- DocumentSegment .dataset_id == self .dataset_id ,
140- DocumentSegment .completed_at .isnot (None ),
141- DocumentSegment .status == "completed" ,
142- DocumentSegment .enabled == True ,
143- DocumentSegment .index_node_id .in_ (index_node_ids ),
144- ).all ()
145-
146- if segments :
147- index_node_id_to_position = {id : position for position , id in enumerate (index_node_ids )}
148- sorted_segments = sorted (
149- segments , key = lambda segment : index_node_id_to_position .get (segment .index_node_id , float ("inf" ))
150- )
151- for segment in sorted_segments :
137+ records = RetrievalService .format_retrieval_documents (documents )
138+ if records :
139+ for record in records :
140+ segment = record .segment
152141 if segment .answer :
153142 document_context_list .append (
154- f"question:{ segment .get_sign_content ()} answer:{ segment .answer } "
143+ DocumentContext (
144+ content = f"question:{ segment .get_sign_content ()} answer:{ segment .answer } " ,
145+ score = record .score ,
146+ )
155147 )
156148 else :
157- document_context_list .append (segment .get_sign_content ())
149+ document_context_list .append (
150+ DocumentContext (
151+ content = segment .get_sign_content (),
152+ score = record .score ,
153+ )
154+ )
155+ retrieval_resource_list = []
158156 if self .return_resource :
159- context_list = []
160- resource_number = 1
161- for segment in sorted_segments :
162- document_segment = Document .query .filter (
163- Document .id == segment .document_id ,
164- Document .enabled == True ,
165- Document .archived == False ,
157+ for record in records :
158+ segment = record . segment
159+ dataset = Dataset . query . filter_by ( id = segment . dataset_id ). first ()
160+ document = DatasetDocument .query .filter (
161+ DatasetDocument .id == segment .document_id ,
162+ DatasetDocument .enabled == True ,
163+ DatasetDocument .archived == False ,
166164 ).first ()
167- if not document_segment :
168- continue
169- if dataset and document_segment :
165+ if dataset and document :
170166 source = {
171- "position" : resource_number ,
172167 "dataset_id" : dataset .id ,
173168 "dataset_name" : dataset .name ,
174- "document_id" : document_segment .id ,
175- "document_name" : document_segment .name ,
176- "data_source_type" : document_segment .data_source_type ,
169+ "document_id" : document .id , # type: ignore
170+ "document_name" : document .name , # type: ignore
171+ "data_source_type" : document .data_source_type , # type: ignore
177172 "segment_id" : segment .id ,
178173 "retriever_from" : self .retriever_from ,
179- "score" : document_score_list . get ( segment . index_node_id , None ) ,
174+ "score" : record . score or 0.0 ,
180175 }
176+
181177 if self .retriever_from == "dev" :
182178 source ["hit_count" ] = segment .hit_count
183179 source ["word_count" ] = segment .word_count
@@ -187,10 +183,19 @@ def _run(self, query: str) -> str:
187183 source ["content" ] = f"question:{ segment .content } \n answer:{ segment .answer } "
188184 else :
189185 source ["content" ] = segment .content
190- context_list .append (source )
191- resource_number += 1
192-
193- for hit_callback in self .hit_callbacks :
194- hit_callback .return_retriever_resource_info (context_list )
186+ retrieval_resource_list .append (source )
195187
196- return str ("\n " .join (document_context_list ))
188+ if self .return_resource and retrieval_resource_list :
189+ retrieval_resource_list = sorted (
190+ retrieval_resource_list ,
191+ key = lambda x : x .get ("score" ) or 0.0 ,
192+ reverse = True ,
193+ )
194+ for position , item in enumerate (retrieval_resource_list , start = 1 ): # type: ignore
195+ item ["position" ] = position # type: ignore
196+ for hit_callback in self .hit_callbacks :
197+ hit_callback .return_retriever_resource_info (retrieval_resource_list )
198+ if document_context_list :
199+ document_context_list = sorted (document_context_list , key = lambda x : x .score or 0.0 , reverse = True )
200+ return str ("\n " .join ([document_context .content for document_context in document_context_list ]))
201+ return ""
0 commit comments