Skip to content

Commit 485da15

Browse files
refactor(api): replace dict/Mapping with TypedDict in core/rag retrieval_service.py (#33615)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
1 parent d7f70f3 commit 485da15

File tree

18 files changed

+165
-71
lines changed

18 files changed

+165
-71
lines changed

api/core/app/app_config/easy_ui_based_app/dataset/manager.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
ModelConfig,
99
)
1010
from core.entities.agent_entities import PlanningStrategy
11+
from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict
1112
from models.model import AppMode, AppModelConfigDict
1213
from services.dataset_service import DatasetService
1314

@@ -117,8 +118,10 @@ def convert(cls, config: AppModelConfigDict) -> DatasetEntity | None:
117118
score_threshold=float(score_threshold_val)
118119
if dataset_configs.get("score_threshold_enabled", False) and score_threshold_val is not None
119120
else None,
120-
reranking_model=reranking_model_val if isinstance(reranking_model_val, dict) else None,
121-
weights=weights_val if isinstance(weights_val, dict) else None,
121+
reranking_model=cast(RerankingModelDict, reranking_model_val)
122+
if isinstance(reranking_model_val, dict)
123+
else None,
124+
weights=cast(WeightsDict, weights_val) if isinstance(weights_val, dict) else None,
122125
reranking_enabled=bool(dataset_configs.get("reranking_enabled", True)),
123126
rerank_mode=dataset_configs.get("reranking_mode", "reranking_model"),
124127
metadata_filtering_mode=cast(

api/core/app/app_config/entities.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from pydantic import BaseModel, Field
66

7+
from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict
78
from dify_graph.file import FileUploadConfig
89
from dify_graph.model_runtime.entities.llm_entities import LLMMode
910
from dify_graph.model_runtime.entities.message_entities import PromptMessageRole
@@ -194,8 +195,8 @@ def value_of(cls, value: str):
194195
top_k: int | None = None
195196
score_threshold: float | None = 0.0
196197
rerank_mode: str | None = "reranking_model"
197-
reranking_model: dict | None = None
198-
weights: dict | None = None
198+
reranking_model: RerankingModelDict | None = None
199+
weights: WeightsDict | None = None
199200
reranking_enabled: bool | None = True
200201
metadata_filtering_mode: Literal["disabled", "automatic", "manual"] | None = "disabled"
201202
metadata_model_config: ModelConfig | None = None

api/core/rag/data_post_processor/data_post_processor.py

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing_extensions import TypedDict
2+
13
from core.model_manager import ModelInstance, ModelManager
24
from core.rag.data_post_processor.reorder import ReorderRunner
35
from core.rag.index_processor.constant.query_type import QueryType
@@ -10,15 +12,35 @@
1012
from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError
1113

1214

15+
class RerankingModelDict(TypedDict):
16+
reranking_provider_name: str
17+
reranking_model_name: str
18+
19+
20+
class VectorSettingDict(TypedDict):
21+
vector_weight: float
22+
embedding_provider_name: str
23+
embedding_model_name: str
24+
25+
26+
class KeywordSettingDict(TypedDict):
27+
keyword_weight: float
28+
29+
30+
class WeightsDict(TypedDict):
31+
vector_setting: VectorSettingDict
32+
keyword_setting: KeywordSettingDict
33+
34+
1335
class DataPostProcessor:
1436
"""Interface for data post-processing document."""
1537

1638
def __init__(
1739
self,
1840
tenant_id: str,
1941
reranking_mode: str,
20-
reranking_model: dict | None = None,
21-
weights: dict | None = None,
42+
reranking_model: RerankingModelDict | None = None,
43+
weights: WeightsDict | None = None,
2244
reorder_enabled: bool = False,
2345
):
2446
self.rerank_runner = self._get_rerank_runner(reranking_mode, tenant_id, reranking_model, weights)
@@ -45,8 +67,8 @@ def _get_rerank_runner(
4567
self,
4668
reranking_mode: str,
4769
tenant_id: str,
48-
reranking_model: dict | None = None,
49-
weights: dict | None = None,
70+
reranking_model: RerankingModelDict | None = None,
71+
weights: WeightsDict | None = None,
5072
) -> BaseRerankRunner | None:
5173
if reranking_mode == RerankMode.WEIGHTED_SCORE and weights:
5274
runner = RerankRunnerFactory.create_rerank_runner(
@@ -79,12 +101,14 @@ def _get_reorder_runner(self, reorder_enabled) -> ReorderRunner | None:
79101
return ReorderRunner()
80102
return None
81103

82-
def _get_rerank_model_instance(self, tenant_id: str, reranking_model: dict | None) -> ModelInstance | None:
104+
def _get_rerank_model_instance(
105+
self, tenant_id: str, reranking_model: RerankingModelDict | None
106+
) -> ModelInstance | None:
83107
if reranking_model:
84108
try:
85109
model_manager = ModelManager()
86-
reranking_provider_name = reranking_model.get("reranking_provider_name")
87-
reranking_model_name = reranking_model.get("reranking_model_name")
110+
reranking_provider_name = reranking_model["reranking_provider_name"]
111+
reranking_model_name = reranking_model["reranking_model_name"]
88112
if not reranking_provider_name or not reranking_model_name:
89113
return None
90114
rerank_model_instance = model_manager.get_model_instance(

api/core/rag/datasource/retrieval_service.py

Lines changed: 77 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,20 @@
11
import concurrent.futures
22
import logging
33
from concurrent.futures import ThreadPoolExecutor
4-
from typing import Any
4+
from typing import Any, NotRequired
55

66
from flask import Flask, current_app
77
from sqlalchemy import select
88
from sqlalchemy.orm import Session, load_only
9+
from typing_extensions import TypedDict
910

1011
from configs import dify_config
1112
from core.db.session_factory import session_factory
1213
from core.model_manager import ModelManager
13-
from core.rag.data_post_processor.data_post_processor import DataPostProcessor
14+
from core.rag.data_post_processor.data_post_processor import DataPostProcessor, RerankingModelDict, WeightsDict
1415
from core.rag.datasource.keyword.keyword_factory import Keyword
1516
from core.rag.datasource.vdb.vector_factory import Vector
16-
from core.rag.embedding.retrieval import RetrievalChildChunk, RetrievalSegments
17+
from core.rag.embedding.retrieval import AttachmentInfoDict, RetrievalChildChunk, RetrievalSegments
1718
from core.rag.entities.metadata_entities import MetadataCondition
1819
from core.rag.index_processor.constant.doc_type import DocType
1920
from core.rag.index_processor.constant.index_type import IndexStructureType
@@ -35,7 +36,46 @@
3536
from models.model import UploadFile
3637
from services.external_knowledge_service import ExternalDatasetService
3738

38-
default_retrieval_model = {
39+
40+
class SegmentAttachmentResult(TypedDict):
41+
attachment_info: AttachmentInfoDict
42+
segment_id: str
43+
44+
45+
class SegmentAttachmentInfoResult(TypedDict):
46+
attachment_id: str
47+
attachment_info: AttachmentInfoDict
48+
segment_id: str
49+
50+
51+
class ChildChunkDetail(TypedDict):
52+
id: str
53+
content: str
54+
position: int
55+
score: float
56+
57+
58+
class SegmentChildMapDetail(TypedDict):
59+
max_score: float
60+
child_chunks: list[ChildChunkDetail]
61+
62+
63+
class SegmentRecord(TypedDict):
64+
segment: DocumentSegment
65+
score: NotRequired[float]
66+
child_chunks: NotRequired[list[ChildChunkDetail]]
67+
files: NotRequired[list[AttachmentInfoDict]]
68+
69+
70+
class DefaultRetrievalModelDict(TypedDict):
71+
search_method: RetrievalMethod | str
72+
reranking_enable: bool
73+
reranking_model: RerankingModelDict
74+
top_k: int
75+
score_threshold_enabled: bool
76+
77+
78+
default_retrieval_model: DefaultRetrievalModelDict = {
3979
"search_method": RetrievalMethod.SEMANTIC_SEARCH,
4080
"reranking_enable": False,
4181
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
@@ -56,9 +96,9 @@ def retrieve(
5696
query: str,
5797
top_k: int = 4,
5898
score_threshold: float | None = 0.0,
59-
reranking_model: dict | None = None,
99+
reranking_model: RerankingModelDict | None = None,
60100
reranking_mode: str = "reranking_model",
61-
weights: dict | None = None,
101+
weights: WeightsDict | None = None,
62102
document_ids_filter: list[str] | None = None,
63103
attachment_ids: list | None = None,
64104
):
@@ -235,7 +275,7 @@ def embedding_search(
235275
query: str,
236276
top_k: int,
237277
score_threshold: float | None,
238-
reranking_model: dict | None,
278+
reranking_model: RerankingModelDict | None,
239279
all_documents: list,
240280
retrieval_method: RetrievalMethod,
241281
exceptions: list,
@@ -277,8 +317,8 @@ def embedding_search(
277317
if documents:
278318
if (
279319
reranking_model
280-
and reranking_model.get("reranking_model_name")
281-
and reranking_model.get("reranking_provider_name")
320+
and reranking_model["reranking_model_name"]
321+
and reranking_model["reranking_provider_name"]
282322
and retrieval_method == RetrievalMethod.SEMANTIC_SEARCH
283323
):
284324
data_post_processor = DataPostProcessor(
@@ -288,8 +328,8 @@ def embedding_search(
288328
model_manager = ModelManager()
289329
is_support_vision = model_manager.check_model_support_vision(
290330
tenant_id=dataset.tenant_id,
291-
provider=reranking_model.get("reranking_provider_name") or "",
292-
model=reranking_model.get("reranking_model_name") or "",
331+
provider=reranking_model["reranking_provider_name"],
332+
model=reranking_model["reranking_model_name"],
293333
model_type=ModelType.RERANK,
294334
)
295335
if is_support_vision:
@@ -329,7 +369,7 @@ def full_text_index_search(
329369
query: str,
330370
top_k: int,
331371
score_threshold: float | None,
332-
reranking_model: dict | None,
372+
reranking_model: RerankingModelDict | None,
333373
all_documents: list,
334374
retrieval_method: str,
335375
exceptions: list,
@@ -349,8 +389,8 @@ def full_text_index_search(
349389
if documents:
350390
if (
351391
reranking_model
352-
and reranking_model.get("reranking_model_name")
353-
and reranking_model.get("reranking_provider_name")
392+
and reranking_model["reranking_model_name"]
393+
and reranking_model["reranking_provider_name"]
354394
and retrieval_method == RetrievalMethod.FULL_TEXT_SEARCH
355395
):
356396
data_post_processor = DataPostProcessor(
@@ -459,7 +499,7 @@ def format_retrieval_documents(cls, documents: list[Document]) -> list[Retrieval
459499
segment_ids: list[str] = []
460500
index_node_segments: list[DocumentSegment] = []
461501
segments: list[DocumentSegment] = []
462-
attachment_map: dict[str, list[dict[str, Any]]] = {}
502+
attachment_map: dict[str, list[AttachmentInfoDict]] = {}
463503
child_chunk_map: dict[str, list[ChildChunk]] = {}
464504
doc_segment_map: dict[str, list[str]] = {}
465505
segment_summary_map: dict[str, str] = {} # Map segment_id to summary content
@@ -544,12 +584,12 @@ def format_retrieval_documents(cls, documents: list[Document]) -> list[Retrieval
544584
segment_summary_map[summary.chunk_id] = summary.summary_content
545585

546586
include_segment_ids = set()
547-
segment_child_map: dict[str, dict[str, Any]] = {}
548-
records: list[dict[str, Any]] = []
587+
segment_child_map: dict[str, SegmentChildMapDetail] = {}
588+
records: list[SegmentRecord] = []
549589

550590
for segment in segments:
551591
child_chunks: list[ChildChunk] = child_chunk_map.get(segment.id, [])
552-
attachment_infos: list[dict[str, Any]] = attachment_map.get(segment.id, [])
592+
attachment_infos: list[AttachmentInfoDict] = attachment_map.get(segment.id, [])
553593
ds_dataset_document: DatasetDocument | None = valid_dataset_documents.get(segment.document_id)
554594

555595
if ds_dataset_document and ds_dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
@@ -560,14 +600,14 @@ def format_retrieval_documents(cls, documents: list[Document]) -> list[Retrieval
560600
max_score = summary_score_map.get(segment.id, 0.0)
561601

562602
if child_chunks or attachment_infos:
563-
child_chunk_details = []
603+
child_chunk_details: list[ChildChunkDetail] = []
564604
for child_chunk in child_chunks:
565605
child_document: Document | None = doc_to_document_map.get(child_chunk.index_node_id)
566606
if child_document:
567607
child_score = child_document.metadata.get("score", 0.0)
568608
else:
569609
child_score = 0.0
570-
child_chunk_detail = {
610+
child_chunk_detail: ChildChunkDetail = {
571611
"id": child_chunk.id,
572612
"content": child_chunk.content,
573613
"position": child_chunk.position,
@@ -580,7 +620,7 @@ def format_retrieval_documents(cls, documents: list[Document]) -> list[Retrieval
580620
if file_document:
581621
max_score = max(max_score, file_document.metadata.get("score", 0.0))
582622

583-
map_detail = {
623+
map_detail: SegmentChildMapDetail = {
584624
"max_score": max_score,
585625
"child_chunks": child_chunk_details,
586626
}
@@ -593,7 +633,7 @@ def format_retrieval_documents(cls, documents: list[Document]) -> list[Retrieval
593633
"max_score": summary_score,
594634
"child_chunks": [],
595635
}
596-
record: dict[str, Any] = {
636+
record: SegmentRecord = {
597637
"segment": segment,
598638
}
599639
records.append(record)
@@ -617,19 +657,19 @@ def format_retrieval_documents(cls, documents: list[Document]) -> list[Retrieval
617657
if file_doc:
618658
max_score = max(max_score, file_doc.metadata.get("score", 0.0))
619659

620-
record = {
660+
another_record: SegmentRecord = {
621661
"segment": segment,
622662
"score": max_score,
623663
}
624-
records.append(record)
664+
records.append(another_record)
625665

626666
# Add child chunks information to records
627667
for record in records:
628668
if record["segment"].id in segment_child_map:
629-
record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks") # type: ignore
630-
record["score"] = segment_child_map[record["segment"].id]["max_score"] # type: ignore
669+
record["child_chunks"] = segment_child_map[record["segment"].id]["child_chunks"]
670+
record["score"] = segment_child_map[record["segment"].id]["max_score"]
631671
if record["segment"].id in attachment_map:
632-
record["files"] = attachment_map[record["segment"].id] # type: ignore[assignment]
672+
record["files"] = attachment_map[record["segment"].id]
633673

634674
result: list[RetrievalSegments] = []
635675
for record in records:
@@ -693,9 +733,9 @@ def _retrieve(
693733
query: str | None = None,
694734
top_k: int = 4,
695735
score_threshold: float | None = 0.0,
696-
reranking_model: dict | None = None,
736+
reranking_model: RerankingModelDict | None = None,
697737
reranking_mode: str = "reranking_model",
698-
weights: dict | None = None,
738+
weights: WeightsDict | None = None,
699739
document_ids_filter: list[str] | None = None,
700740
attachment_id: str | None = None,
701741
):
@@ -807,7 +847,7 @@ def _retrieve(
807847
@classmethod
808848
def get_segment_attachment_info(
809849
cls, dataset_id: str, tenant_id: str, attachment_id: str, session: Session
810-
) -> dict[str, Any] | None:
850+
) -> SegmentAttachmentResult | None:
811851
upload_file = session.query(UploadFile).where(UploadFile.id == attachment_id).first()
812852
if upload_file:
813853
attachment_binding = (
@@ -816,7 +856,7 @@ def get_segment_attachment_info(
816856
.first()
817857
)
818858
if attachment_binding:
819-
attachment_info = {
859+
attachment_info: AttachmentInfoDict = {
820860
"id": upload_file.id,
821861
"name": upload_file.name,
822862
"extension": "." + upload_file.extension,
@@ -828,8 +868,10 @@ def get_segment_attachment_info(
828868
return None
829869

830870
@classmethod
831-
def get_segment_attachment_infos(cls, attachment_ids: list[str], session: Session) -> list[dict[str, Any]]:
832-
attachment_infos = []
871+
def get_segment_attachment_infos(
872+
cls, attachment_ids: list[str], session: Session
873+
) -> list[SegmentAttachmentInfoResult]:
874+
attachment_infos: list[SegmentAttachmentInfoResult] = []
833875
upload_files = session.query(UploadFile).where(UploadFile.id.in_(attachment_ids)).all()
834876
if upload_files:
835877
upload_file_ids = [upload_file.id for upload_file in upload_files]
@@ -843,7 +885,7 @@ def get_segment_attachment_infos(cls, attachment_ids: list[str], session: Sessio
843885
if attachment_bindings:
844886
for upload_file in upload_files:
845887
attachment_binding = attachment_binding_map.get(upload_file.id)
846-
attachment_info = {
888+
info: AttachmentInfoDict = {
847889
"id": upload_file.id,
848890
"name": upload_file.name,
849891
"extension": "." + upload_file.extension,
@@ -855,7 +897,7 @@ def get_segment_attachment_infos(cls, attachment_ids: list[str], session: Sessio
855897
attachment_infos.append(
856898
{
857899
"attachment_id": attachment_binding.attachment_id,
858-
"attachment_info": attachment_info,
900+
"attachment_info": info,
859901
"segment_id": attachment_binding.segment_id,
860902
}
861903
)

0 commit comments

Comments
 (0)