Skip to content

Commit 6207999

Browse files
authored
fix: Knowledge Base with Parent-Child segment mode not supported in Agent (langgenius#13663)
1 parent 4e7e172 commit 6207999

File tree

1 file changed

+46
-41
lines changed

1 file changed

+46
-41
lines changed

api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py

Lines changed: 46 additions & 41 deletions
Original file line number | Diff line number | Diff line change
@@ -3,11 +3,13 @@
33
from pydantic import BaseModel, Field
44

55
from core.rag.datasource.retrieval_service import RetrievalService
6+
from core.rag.entities.context_entities import DocumentContext
67
from core.rag.models.document import Document as RetrievalDocument
78
from core.rag.retrieval.retrieval_methods import RetrievalMethod
89
from core.tools.tool.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
910
from extensions.ext_database import db
10-
from models.dataset import Dataset, Document, DocumentSegment
11+
from models.dataset import Dataset
12+
from models.dataset import Document as DatasetDocument
1113
from services.external_knowledge_service import ExternalDatasetService
1214

1315
default_retrieval_model = {
@@ -54,7 +56,6 @@ def _run(self, query: str) -> str:
5456

5557
if not dataset:
5658
return ""
57-
5859
for hit_callback in self.hit_callbacks:
5960
hit_callback.on_query(query, dataset.id)
6061
if dataset.provider == "external":
@@ -125,7 +126,6 @@ def _run(self, query: str) -> str:
125126
)
126127
else:
127128
documents = []
128-
129129
for hit_callback in self.hit_callbacks:
130130
hit_callback.on_tool_end(documents)
131131
document_score_list = {}
@@ -134,50 +134,46 @@ def _run(self, query: str) -> str:
134134
if item.metadata is not None and item.metadata.get("score"):
135135
document_score_list[item.metadata["doc_id"]] = item.metadata["score"]
136136
document_context_list = []
137-
index_node_ids = [document.metadata["doc_id"] for document in documents]
138-
segments = DocumentSegment.query.filter(
139-
DocumentSegment.dataset_id == self.dataset_id,
140-
DocumentSegment.completed_at.isnot(None),
141-
DocumentSegment.status == "completed",
142-
DocumentSegment.enabled == True,
143-
DocumentSegment.index_node_id.in_(index_node_ids),
144-
).all()
145-
146-
if segments:
147-
index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)}
148-
sorted_segments = sorted(
149-
segments, key=lambda segment: index_node_id_to_position.get(segment.index_node_id, float("inf"))
150-
)
151-
for segment in sorted_segments:
137+
records = RetrievalService.format_retrieval_documents(documents)
138+
if records:
139+
for record in records:
140+
segment = record.segment
152141
if segment.answer:
153142
document_context_list.append(
154-
f"question:{segment.get_sign_content()} answer:{segment.answer}"
143+
DocumentContext(
144+
content=f"question:{segment.get_sign_content()} answer:{segment.answer}",
145+
score=record.score,
146+
)
155147
)
156148
else:
157-
document_context_list.append(segment.get_sign_content())
149+
document_context_list.append(
150+
DocumentContext(
151+
content=segment.get_sign_content(),
152+
score=record.score,
153+
)
154+
)
155+
retrieval_resource_list = []
158156
if self.return_resource:
159-
context_list = []
160-
resource_number = 1
161-
for segment in sorted_segments:
162-
document_segment = Document.query.filter(
163-
Document.id == segment.document_id,
164-
Document.enabled == True,
165-
Document.archived == False,
157+
for record in records:
158+
segment = record.segment
159+
dataset = Dataset.query.filter_by(id=segment.dataset_id).first()
160+
document = DatasetDocument.query.filter(
161+
DatasetDocument.id == segment.document_id,
162+
DatasetDocument.enabled == True,
163+
DatasetDocument.archived == False,
166164
).first()
167-
if not document_segment:
168-
continue
169-
if dataset and document_segment:
165+
if dataset and document:
170166
source = {
171-
"position": resource_number,
172167
"dataset_id": dataset.id,
173168
"dataset_name": dataset.name,
174-
"document_id": document_segment.id,
175-
"document_name": document_segment.name,
176-
"data_source_type": document_segment.data_source_type,
169+
"document_id": document.id, # type: ignore
170+
"document_name": document.name, # type: ignore
171+
"data_source_type": document.data_source_type, # type: ignore
177172
"segment_id": segment.id,
178173
"retriever_from": self.retriever_from,
179-
"score": document_score_list.get(segment.index_node_id, None),
174+
"score": record.score or 0.0,
180175
}
176+
181177
if self.retriever_from == "dev":
182178
source["hit_count"] = segment.hit_count
183179
source["word_count"] = segment.word_count
@@ -187,10 +183,19 @@ def _run(self, query: str) -> str:
187183
source["content"] = f"question:{segment.content} \nanswer:{segment.answer}"
188184
else:
189185
source["content"] = segment.content
190-
context_list.append(source)
191-
resource_number += 1
192-
193-
for hit_callback in self.hit_callbacks:
194-
hit_callback.return_retriever_resource_info(context_list)
186+
retrieval_resource_list.append(source)
195187

196-
return str("\n".join(document_context_list))
188+
if self.return_resource and retrieval_resource_list:
189+
retrieval_resource_list = sorted(
190+
retrieval_resource_list,
191+
key=lambda x: x.get("score") or 0.0,
192+
reverse=True,
193+
)
194+
for position, item in enumerate(retrieval_resource_list, start=1): # type: ignore
195+
item["position"] = position # type: ignore
196+
for hit_callback in self.hit_callbacks:
197+
hit_callback.return_retriever_resource_info(retrieval_resource_list)
198+
if document_context_list:
199+
document_context_list = sorted(document_context_list, key=lambda x: x.score or 0.0, reverse=True)
200+
return str("\n".join([document_context.content for document_context in document_context_list]))
201+
return ""

0 commit comments

Comments
 (0)