|
1 | 1 | from typing import Tuple, List, Literal |
2 | 2 | from pathlib import Path |
3 | | -from enum import Enum |
4 | | -from typing_extensions import Annotated |
5 | 3 |
|
6 | 4 | from langchain_core.tools import tool |
7 | 5 | from langchain.vectorstores.base import VectorStore |
|
13 | 11 | DATA = BASE / "data" |
14 | 12 |
|
15 | 13 |
|
16 | | -class SourceTag(str, Enum): |
17 | | - TEXT = "text_chunk" |
18 | | - IMAGE = "image_text" |
19 | | - TABLE = "table_summary" |
20 | | - |
21 | | - |
22 | 14 | def initialize_retrieve_tool(vector_store: VectorStore): |
23 | 15 | @tool( |
24 | 16 | description=( |
25 | 17 | """ |
26 | | - Semantic search over your docs. Valid tags are |
27 | | - "text_chunk", "image_text", and "table_summary". |
28 | | - """ |
| 18 | + Semantic search over your docs. ONLY valid tags are |
| 19 | + "text_chunk" (chunks over pdf, word, txt, etc), |
| 20 | + "image_text" (texts extracted through OCR per image), or |
| 21 | + "table_summary" (summary cards of excel sheets or csv files) |
| 22 | + """ |
29 | 23 | ), |
30 | 24 | response_format="content_and_artifact", |
31 | 25 | ) |
32 | 26 | def retrieve( |
33 | | - query: str, |
34 | | - tag: Annotated[ |
35 | | - Literal["text_chunk", "image_text", "table_summary"], |
36 | | - """ |
37 | | - Select between |
38 | | - "text_chunk" (chunks over pdf, word, txt, etc), |
39 | | - "image_text" (texts extracted through OCR per image), or |
40 | | - "table_summary" (summary cards of excel sheets or csv files) |
41 | | - """, |
42 | | - ], |
| 27 | + query: str, tag: Literal["text_chunk", "image_text", "table_summary"] |
43 | 28 | ) -> Tuple[str, List[Document]]: |
44 | | - """ |
45 | | - Args: |
46 | | - query: keywords or natural-language question. |
47 | | - tag: which subset to search ("text_chunk", "image_text", "table_summary"). |
48 | | - Returns: |
49 | | - (summary_string, list_of_Documents) |
50 | | - """ |
51 | 29 | retrieved_docs = vector_store.similarity_search( |
52 | 30 | query, |
53 | | - filter={"source_type": tag}, |
54 | 31 | k=2, |
| 32 | + filter=lambda doc: doc.metadata.get("source_type") == tag, |
55 | 33 | ) |
56 | 34 | serialized = "\n\n".join( |
57 | 35 | (f"Source: {doc.metadata}\nContent: {doc.page_content}") |
|
0 commit comments