Skip to content

Commit 07f0717

Browse files
committed
added semantic search over tags without errors
1 parent f9908d1 commit 07f0717

File tree

1 file changed

+7
-29
lines changed

1 file changed

+7
-29
lines changed

src/any_chatbot/tools.py

Lines changed: 7 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,5 @@
11
from typing import Tuple, List, Literal
22
from pathlib import Path
3-
from enum import Enum
4-
from typing_extensions import Annotated
53

64
from langchain_core.tools import tool
75
from langchain.vectorstores.base import VectorStore
@@ -13,45 +11,25 @@
1311
DATA = BASE / "data"
1412

1513

16-
class SourceTag(str, Enum):
17-
TEXT = "text_chunk"
18-
IMAGE = "image_text"
19-
TABLE = "table_summary"
20-
21-
2214
def initialize_retrieve_tool(vector_store: VectorStore):
2315
@tool(
2416
description=(
2517
"""
26-
Semantic search over your docs. Valid tags are
27-
"text_chunk", "image_text", and "table_summary".
28-
"""
18+
Semantic search over your docs. ONLY valid tags are
19+
"text_chunk" (chunks over pdf, word, txt, etc),
20+
"image_text" (texts extracted through OCR per image), or
21+
"table_summary" (summary cards of excel sheets or csv files)
22+
"""
2923
),
3024
response_format="content_and_artifact",
3125
)
3226
def retrieve(
33-
query: str,
34-
tag: Annotated[
35-
Literal["text_chunk", "image_text", "table_summary"],
36-
"""
37-
Select between
38-
"text_chunk" (chunks over pdf, word, txt, etc),
39-
"image_text" (texts extracted through OCR per image), or
40-
"table_summary" (summary cards of excel sheets or csv files)
41-
""",
42-
],
27+
query: str, tag: Literal["text_chunk", "image_text", "table_summary"]
4328
) -> Tuple[str, List[Document]]:
44-
"""
45-
Args:
46-
query: keywords or natural-language question.
47-
tag: which subset to search ("text_chunk", "image_text", "table_summary").
48-
Returns:
49-
(summary_string, list_of_Documents)
50-
"""
5129
retrieved_docs = vector_store.similarity_search(
5230
query,
53-
filter={"source_type": tag},
5431
k=2,
32+
filter=lambda doc: doc.metadata.get("source_type") == tag,
5533
)
5634
serialized = "\n\n".join(
5735
(f"Source: {doc.metadata}\nContent: {doc.page_content}")

0 commit comments

Comments
 (0)