Skip to content

Commit 219fe97

Browse files
committed
feat: Add document_id_list parameter to query methods in vector classes
1 parent 09c5c9c commit 219fe97

File tree

4 files changed

+9
-3
lines changed

4 files changed

+9
-3
lines changed

apps/application/chat_pipeline/step/search_dataset_step/impl/base_search_dataset_step.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -67,7 +67,7 @@ def execute(self, problem_text: str, knowledge_id_list: list[str], exclude_docum
6767
embedding_model = ModelManage.get_model(model_id, lambda _id: get_model(model))
6868
embedding_value = embedding_model.embed_query(exec_problem_text)
6969
vector = VectorStore.get_embedding_vector()
70-
embedding_list = vector.query(exec_problem_text, embedding_value, knowledge_id_list, exclude_document_id_list,
70+
embedding_list = vector.query(exec_problem_text, embedding_value, knowledge_id_list, None, exclude_document_id_list,
7171
exclude_paragraph_id_list, True, top_n, similarity, SearchMode(search_mode))
7272
if embedding_list is None:
7373
return []

apps/application/flow/step_node/search_knowledge_node/impl/base_search_knowledge_node.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -79,12 +79,14 @@ def execute(self, knowledge_id_list, knowledge_setting, question, show_knowledge
7979
self.context['question'] = question
8080
self.context['show_knowledge'] = show_knowledge
8181

82+
document_id_list = None
8283
if search_scope_type == 'referencing': # 引用上一步知识库/文档
8384
if search_scope_source == 'knowledge': # 知识库
8485
knowledge_id_list = self.get_reference_content(search_scope_reference)
8586
else: # 文档
87+
document_id_list = self.get_reference_content(search_scope_reference)
8688
knowledge_id_list = QuerySet(Document).filter(
87-
id__in=self.get_reference_content(search_scope_reference)
89+
id__in=document_id_list
8890
).values_list(
8991
'knowledge_id', flat=True
9092
).distinct()
@@ -105,7 +107,7 @@ def execute(self, knowledge_id_list, knowledge_setting, question, show_knowledge
105107
QuerySet(Document).filter(
106108
knowledge_id__in=knowledge_id_list,
107109
is_active=False)]
108-
embedding_list = vector.query(question, embedding_value, knowledge_id_list, exclude_document_id_list,
110+
embedding_list = vector.query(question, embedding_value, knowledge_id_list, document_id_list, exclude_document_id_list,
109111
exclude_paragraph_id_list, True, knowledge_setting.get('top_n'),
110112
knowledge_setting.get('similarity'),
111113
SearchMode(knowledge_setting.get('search_mode')))

apps/knowledge/vector/base_vector.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -126,6 +126,7 @@ def search(self, query_text, knowledge_id_list: list[str], exclude_document_id_l
126126

127127
@abstractmethod
128128
def query(self, query_text: str, query_embedding: List[float], knowledge_id_list: list[str],
129+
document_id_list: list[str] | None,
129130
exclude_document_id_list: list[str],
130131
exclude_paragraph_list: list[str], is_active: bool, top_n: int, similarity: float,
131132
search_mode: SearchMode):

apps/knowledge/vector/pg_vector.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,13 +97,16 @@ def hit_test(self, query_text, knowledge_id_list: list[str], exclude_document_id
9797
return search_handle.handle(query_set, query_text, embedding_query, top_number, similarity, search_mode)
9898

9999
def query(self, query_text: str, query_embedding: List[float], knowledge_id_list: list[str],
100+
document_id_list: list[str],
100101
exclude_document_id_list: list[str],
101102
exclude_paragraph_list: list[str], is_active: bool, top_n: int, similarity: float,
102103
search_mode: SearchMode):
103104
exclude_dict = {}
104105
if knowledge_id_list is None or len(knowledge_id_list) == 0:
105106
return []
106107
query_set = QuerySet(Embedding).filter(knowledge_id__in=knowledge_id_list, is_active=is_active)
108+
if document_id_list is not None and len(document_id_list) > 0:
109+
query_set = query_set.filter(document_id__in=document_id_list)
107110
if exclude_document_id_list is not None and len(exclude_document_id_list) > 0:
108111
query_set = query_set.exclude(document_id__in=exclude_document_id_list)
109112
if exclude_paragraph_list is not None and len(exclude_paragraph_list) > 0:

0 commit comments

Comments
 (0)