11import concurrent .futures
22import logging
33from concurrent .futures import ThreadPoolExecutor
4- from typing import Any
4+ from typing import Any , NotRequired
55
66from flask import Flask , current_app
77from sqlalchemy import select
88from sqlalchemy .orm import Session , load_only
9+ from typing_extensions import TypedDict
910
1011from configs import dify_config
1112from core .db .session_factory import session_factory
1213from core .model_manager import ModelManager
13- from core .rag .data_post_processor .data_post_processor import DataPostProcessor
14+ from core .rag .data_post_processor .data_post_processor import DataPostProcessor , RerankingModelDict , WeightsDict
1415from core .rag .datasource .keyword .keyword_factory import Keyword
1516from core .rag .datasource .vdb .vector_factory import Vector
16- from core .rag .embedding .retrieval import RetrievalChildChunk , RetrievalSegments
17+ from core .rag .embedding .retrieval import AttachmentInfoDict , RetrievalChildChunk , RetrievalSegments
1718from core .rag .entities .metadata_entities import MetadataCondition
1819from core .rag .index_processor .constant .doc_type import DocType
1920from core .rag .index_processor .constant .index_type import IndexStructureType
3536from models .model import UploadFile
3637from services .external_knowledge_service import ExternalDatasetService
3738
38- default_retrieval_model = {
39+
class SegmentAttachmentResult(TypedDict):
    """Typed result for a single attachment lookup.

    Used as the (optional) return annotation of
    ``get_segment_attachment_info``.
    """

    # Descriptor of the uploaded file (id, name, extension, ...).
    attachment_info: AttachmentInfoDict
    # Identifier of the segment associated with the attachment.
    # NOTE(review): presumably taken from the attachment binding, as in
    # get_segment_attachment_infos -- confirm against the full method body.
    segment_id: str
43+
44+
class SegmentAttachmentInfoResult(TypedDict):
    """Typed element of the list returned by ``get_segment_attachment_infos``.

    Each entry pairs an attachment binding with the descriptor of the
    uploaded file it refers to.
    """

    # ``attachment_id`` of the attachment binding.
    attachment_id: str
    # Descriptor of the uploaded file (id, name, extension, ...).
    attachment_info: AttachmentInfoDict
    # ``segment_id`` of the attachment binding.
    segment_id: str
49+
50+
class ChildChunkDetail(TypedDict):
    """Typed view of one child chunk as emitted by ``format_retrieval_documents``."""

    # Child chunk identifier (``child_chunk.id``).
    id: str
    # Text content of the child chunk (``child_chunk.content``).
    content: str
    # Position of the child chunk (``child_chunk.position``).
    position: int
    # Retrieval score of the chunk.
    # NOTE(review): presumably the matching document's ``metadata["score"]``,
    # 0.0 when no document matched -- confirm against the full method body.
    score: float
56+
57+
class SegmentChildMapDetail(TypedDict):
    """Per-segment aggregate stored in ``segment_child_map``.

    Built in ``format_retrieval_documents`` and later copied onto the
    matching record as its ``score`` and ``child_chunks``.
    """

    # Highest score seen for the segment.
    max_score: float
    # Details of the child chunks collected for the segment (may be empty).
    child_chunks: list[ChildChunkDetail]
61+
62+
class SegmentRecord(TypedDict):
    """Intermediate record assembled by ``format_retrieval_documents``.

    Only ``segment`` is set when the record is created; the remaining keys
    are optional because they are filled in later (or not at all).
    """

    # The document segment this record describes.
    segment: DocumentSegment
    # Score for the segment; absent until assigned.
    score: NotRequired[float]
    # Child chunk details, set when the segment has an entry in
    # ``segment_child_map``.
    child_chunks: NotRequired[list[ChildChunkDetail]]
    # Attachment descriptors, set when the segment has an entry in
    # ``attachment_map``.
    files: NotRequired[list[AttachmentInfoDict]]
68+
69+
class DefaultRetrievalModelDict(TypedDict):
    """Shape of the module-level ``default_retrieval_model`` mapping."""

    # Retrieval strategy; either a ``RetrievalMethod`` member or a string.
    search_method: RetrievalMethod | str
    # Whether reranking is turned on.
    reranking_enable: bool
    # Provider name and model name of the reranking model.
    reranking_model: RerankingModelDict
    # Number of results to retrieve.
    top_k: int
    # Whether a score threshold is applied.
    score_threshold_enabled: bool
76+
77+
78+ default_retrieval_model : DefaultRetrievalModelDict = {
3979 "search_method" : RetrievalMethod .SEMANTIC_SEARCH ,
4080 "reranking_enable" : False ,
4181 "reranking_model" : {"reranking_provider_name" : "" , "reranking_model_name" : "" },
@@ -56,9 +96,9 @@ def retrieve(
5696 query : str ,
5797 top_k : int = 4 ,
5898 score_threshold : float | None = 0.0 ,
59- reranking_model : dict | None = None ,
99+ reranking_model : RerankingModelDict | None = None ,
60100 reranking_mode : str = "reranking_model" ,
61- weights : dict | None = None ,
101+ weights : WeightsDict | None = None ,
62102 document_ids_filter : list [str ] | None = None ,
63103 attachment_ids : list | None = None ,
64104 ):
@@ -235,7 +275,7 @@ def embedding_search(
235275 query : str ,
236276 top_k : int ,
237277 score_threshold : float | None ,
238- reranking_model : dict | None ,
278+ reranking_model : RerankingModelDict | None ,
239279 all_documents : list ,
240280 retrieval_method : RetrievalMethod ,
241281 exceptions : list ,
@@ -277,8 +317,8 @@ def embedding_search(
277317 if documents :
278318 if (
279319 reranking_model
280- and reranking_model . get ( "reranking_model_name" )
281- and reranking_model . get ( "reranking_provider_name" )
320+ and reranking_model [ "reranking_model_name" ]
321+ and reranking_model [ "reranking_provider_name" ]
282322 and retrieval_method == RetrievalMethod .SEMANTIC_SEARCH
283323 ):
284324 data_post_processor = DataPostProcessor (
@@ -288,8 +328,8 @@ def embedding_search(
288328 model_manager = ModelManager ()
289329 is_support_vision = model_manager .check_model_support_vision (
290330 tenant_id = dataset .tenant_id ,
291- provider = reranking_model . get ( "reranking_provider_name" ) or "" ,
292- model = reranking_model . get ( "reranking_model_name" ) or "" ,
331+ provider = reranking_model [ "reranking_provider_name" ] ,
332+ model = reranking_model [ "reranking_model_name" ] ,
293333 model_type = ModelType .RERANK ,
294334 )
295335 if is_support_vision :
@@ -329,7 +369,7 @@ def full_text_index_search(
329369 query : str ,
330370 top_k : int ,
331371 score_threshold : float | None ,
332- reranking_model : dict | None ,
372+ reranking_model : RerankingModelDict | None ,
333373 all_documents : list ,
334374 retrieval_method : str ,
335375 exceptions : list ,
@@ -349,8 +389,8 @@ def full_text_index_search(
349389 if documents :
350390 if (
351391 reranking_model
352- and reranking_model . get ( "reranking_model_name" )
353- and reranking_model . get ( "reranking_provider_name" )
392+ and reranking_model [ "reranking_model_name" ]
393+ and reranking_model [ "reranking_provider_name" ]
354394 and retrieval_method == RetrievalMethod .FULL_TEXT_SEARCH
355395 ):
356396 data_post_processor = DataPostProcessor (
@@ -459,7 +499,7 @@ def format_retrieval_documents(cls, documents: list[Document]) -> list[Retrieval
459499 segment_ids : list [str ] = []
460500 index_node_segments : list [DocumentSegment ] = []
461501 segments : list [DocumentSegment ] = []
462- attachment_map : dict [str , list [dict [ str , Any ] ]] = {}
502+ attachment_map : dict [str , list [AttachmentInfoDict ]] = {}
463503 child_chunk_map : dict [str , list [ChildChunk ]] = {}
464504 doc_segment_map : dict [str , list [str ]] = {}
465505 segment_summary_map : dict [str , str ] = {} # Map segment_id to summary content
@@ -544,12 +584,12 @@ def format_retrieval_documents(cls, documents: list[Document]) -> list[Retrieval
544584 segment_summary_map [summary .chunk_id ] = summary .summary_content
545585
546586 include_segment_ids = set ()
547- segment_child_map : dict [str , dict [ str , Any ] ] = {}
548- records : list [dict [ str , Any ] ] = []
587+ segment_child_map : dict [str , SegmentChildMapDetail ] = {}
588+ records : list [SegmentRecord ] = []
549589
550590 for segment in segments :
551591 child_chunks : list [ChildChunk ] = child_chunk_map .get (segment .id , [])
552- attachment_infos : list [dict [ str , Any ] ] = attachment_map .get (segment .id , [])
592+ attachment_infos : list [AttachmentInfoDict ] = attachment_map .get (segment .id , [])
553593 ds_dataset_document : DatasetDocument | None = valid_dataset_documents .get (segment .document_id )
554594
555595 if ds_dataset_document and ds_dataset_document .doc_form == IndexStructureType .PARENT_CHILD_INDEX :
@@ -560,14 +600,14 @@ def format_retrieval_documents(cls, documents: list[Document]) -> list[Retrieval
560600 max_score = summary_score_map .get (segment .id , 0.0 )
561601
562602 if child_chunks or attachment_infos :
563- child_chunk_details = []
603+ child_chunk_details : list [ ChildChunkDetail ] = []
564604 for child_chunk in child_chunks :
565605 child_document : Document | None = doc_to_document_map .get (child_chunk .index_node_id )
566606 if child_document :
567607 child_score = child_document .metadata .get ("score" , 0.0 )
568608 else :
569609 child_score = 0.0
570- child_chunk_detail = {
610+ child_chunk_detail : ChildChunkDetail = {
571611 "id" : child_chunk .id ,
572612 "content" : child_chunk .content ,
573613 "position" : child_chunk .position ,
@@ -580,7 +620,7 @@ def format_retrieval_documents(cls, documents: list[Document]) -> list[Retrieval
580620 if file_document :
581621 max_score = max (max_score , file_document .metadata .get ("score" , 0.0 ))
582622
583- map_detail = {
623+ map_detail : SegmentChildMapDetail = {
584624 "max_score" : max_score ,
585625 "child_chunks" : child_chunk_details ,
586626 }
@@ -593,7 +633,7 @@ def format_retrieval_documents(cls, documents: list[Document]) -> list[Retrieval
593633 "max_score" : summary_score ,
594634 "child_chunks" : [],
595635 }
596- record : dict [ str , Any ] = {
636+ record : SegmentRecord = {
597637 "segment" : segment ,
598638 }
599639 records .append (record )
@@ -617,19 +657,19 @@ def format_retrieval_documents(cls, documents: list[Document]) -> list[Retrieval
617657 if file_doc :
618658 max_score = max (max_score , file_doc .metadata .get ("score" , 0.0 ))
619659
620- record = {
660+ another_record : SegmentRecord = {
621661 "segment" : segment ,
622662 "score" : max_score ,
623663 }
624- records .append (record )
664+ records .append (another_record )
625665
626666 # Add child chunks information to records
627667 for record in records :
628668 if record ["segment" ].id in segment_child_map :
629- record ["child_chunks" ] = segment_child_map [record ["segment" ].id ]. get ( "child_chunks" ) # type: ignore
630- record ["score" ] = segment_child_map [record ["segment" ].id ]["max_score" ] # type: ignore
669+ record ["child_chunks" ] = segment_child_map [record ["segment" ].id ][ "child_chunks" ]
670+ record ["score" ] = segment_child_map [record ["segment" ].id ]["max_score" ]
631671 if record ["segment" ].id in attachment_map :
632- record ["files" ] = attachment_map [record ["segment" ].id ] # type: ignore[assignment]
672+ record ["files" ] = attachment_map [record ["segment" ].id ]
633673
634674 result : list [RetrievalSegments ] = []
635675 for record in records :
@@ -693,9 +733,9 @@ def _retrieve(
693733 query : str | None = None ,
694734 top_k : int = 4 ,
695735 score_threshold : float | None = 0.0 ,
696- reranking_model : dict | None = None ,
736+ reranking_model : RerankingModelDict | None = None ,
697737 reranking_mode : str = "reranking_model" ,
698- weights : dict | None = None ,
738+ weights : WeightsDict | None = None ,
699739 document_ids_filter : list [str ] | None = None ,
700740 attachment_id : str | None = None ,
701741 ):
@@ -807,7 +847,7 @@ def _retrieve(
807847 @classmethod
808848 def get_segment_attachment_info (
809849 cls , dataset_id : str , tenant_id : str , attachment_id : str , session : Session
810- ) -> dict [ str , Any ] | None :
850+ ) -> SegmentAttachmentResult | None :
811851 upload_file = session .query (UploadFile ).where (UploadFile .id == attachment_id ).first ()
812852 if upload_file :
813853 attachment_binding = (
@@ -816,7 +856,7 @@ def get_segment_attachment_info(
816856 .first ()
817857 )
818858 if attachment_binding :
819- attachment_info = {
859+ attachment_info : AttachmentInfoDict = {
820860 "id" : upload_file .id ,
821861 "name" : upload_file .name ,
822862 "extension" : "." + upload_file .extension ,
@@ -828,8 +868,10 @@ def get_segment_attachment_info(
828868 return None
829869
830870 @classmethod
831- def get_segment_attachment_infos (cls , attachment_ids : list [str ], session : Session ) -> list [dict [str , Any ]]:
832- attachment_infos = []
871+ def get_segment_attachment_infos (
872+ cls , attachment_ids : list [str ], session : Session
873+ ) -> list [SegmentAttachmentInfoResult ]:
874+ attachment_infos : list [SegmentAttachmentInfoResult ] = []
833875 upload_files = session .query (UploadFile ).where (UploadFile .id .in_ (attachment_ids )).all ()
834876 if upload_files :
835877 upload_file_ids = [upload_file .id for upload_file in upload_files ]
@@ -843,7 +885,7 @@ def get_segment_attachment_infos(cls, attachment_ids: list[str], session: Sessio
843885 if attachment_bindings :
844886 for upload_file in upload_files :
845887 attachment_binding = attachment_binding_map .get (upload_file .id )
846- attachment_info = {
888+ info : AttachmentInfoDict = {
847889 "id" : upload_file .id ,
848890 "name" : upload_file .name ,
849891 "extension" : "." + upload_file .extension ,
@@ -855,7 +897,7 @@ def get_segment_attachment_infos(cls, attachment_ids: list[str], session: Sessio
855897 attachment_infos .append (
856898 {
857899 "attachment_id" : attachment_binding .attachment_id ,
858- "attachment_info" : attachment_info ,
900+ "attachment_info" : info ,
859901 "segment_id" : attachment_binding .segment_id ,
860902 }
861903 )
0 commit comments