
Commit ecf3781

transferred notebook to Python scripts with DRY and fixed dir indexing errors
1 parent c170fe0 commit ecf3781

5 files changed: +138 −2 lines changed

notebooks/experiments/00_semantic_search.ipynb

Lines changed: 0 additions & 1 deletion
@@ -191,7 +191,6 @@
 " TextLoader,\n",
 " UnstructuredWordDocumentLoader,\n",
 " Unsc\n",
-" \n",
 ")\n",
 "from langchain.document_loaders import UnstructuredFileLoader\n",
 "\n",

requirements.txt

Lines changed: 2 additions & 1 deletion
@@ -6,4 +6,5 @@ langchain-pinecone
 pypdf
 python-dotenv
 pinecone
-langgraph
+langgraph
+unstructured[pdf,docx,pptx,md]
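
The new unstructured[pdf,docx,pptx,md] extra supplies the file-type parsers that UnstructuredFileLoader (see indexing.py below) relies on. A quick sanity check that the extras installed, as a sketch; data/example.pdf is a hypothetical file:

# sketch: partition() is the generic entry point the Unstructured loaders build on
from unstructured.partition.auto import partition

elements = partition(filename="data/example.pdf")  # hypothetical sample file
print(f"Parsed {len(elements)} elements")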

src/any_chatbot/agent.py

Lines changed: 61 additions & 0 deletions
@@ -0,0 +1,61 @@
+import getpass
+import os
+import random
+from dotenv import load_dotenv
+from pathlib import Path
+
+from langgraph.prebuilt import create_react_agent
+from langgraph.checkpoint.memory import MemorySaver
+from langchain.chat_models import init_chat_model
+
+from any_chatbot.indexing import index_text_docs
+from any_chatbot.tools import initialize_retrieve_tool
+
+load_dotenv()
+
+BASE = Path(__file__).parent.parent.parent
+DATA = BASE / "data"
+OUTPUTS = BASE / "outputs"
+
+# INDEXING
+embeddings, vector_store = index_text_docs(DATA)
+
+# BUILD LLM
+if not os.environ.get("GOOGLE_API_KEY"):
+    os.environ["GOOGLE_API_KEY"] = getpass.getpass("Enter API key for Google Gemini: ")
+llm = init_chat_model("gemini-2.0-flash", model_provider="google_genai")
+
+# LOAD TOOLS
+retrieve_tool = initialize_retrieve_tool(vector_store)
+
+# BUILD AGENT
+# build checkpointer so the agent keeps per-thread conversation state
+memory = MemorySaver()
+agent_executor = create_react_agent(llm, [retrieve_tool], checkpointer=memory)
+
+# save the architecture graph image
+png_bytes = agent_executor.get_graph().draw_mermaid_png()
+with open(OUTPUTS / "graph.png", "wb") as f:
+    f.write(png_bytes)
+print("Wrote graph.png")
+
+# PROMPT
+# use a fresh thread ID per run; a fixed ID (e.g. "abc123") would resume memory
+config = {"configurable": {"thread_id": random.random()}}
+
+input_message = (
+    "First retrieve what the revenue for Nike in 2023 was using the function call.\n\n"
+    "Once you get the answer, do a second retrieve to tell me which distribution centers Nike has.\n\n"
+    "Once you get the second answer, tell me how many employees Nike has. You can retrieve MULTIPLE TIMES.\n\n"
+    "Base your answers only on the information retrieved through the function call you have."
+)
+
+for event in agent_executor.stream(
+    {"messages": [{"role": "user", "content": input_message}]},
+    stream_mode="values",
+    config=config,
+):
+    event["messages"][-1].pretty_print()
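
Because thread_id is a fresh random value on every run, each run starts a new conversation. A minimal sketch (not part of the commit) of how a fixed thread ID would let MemorySaver resume state across calls; "nike-session-1" is a hypothetical ID:

config = {"configurable": {"thread_id": "nike-session-1"}}  # hypothetical stable ID

agent_executor.invoke(
    {"messages": [{"role": "user", "content": "What was Nike's revenue in 2023?"}]},
    config=config,
)
# same thread_id: the checkpointer restores the earlier messages, so a
# follow-up can refer back to the previous answer
result = agent_executor.invoke(
    {"messages": [{"role": "user", "content": "And how many employees does it have?"}]},
    config=config,
)
print(result["messages"][-1].content)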

src/any_chatbot/indexing.py

Lines changed: 54 additions & 0 deletions
@@ -0,0 +1,54 @@
+from dotenv import load_dotenv
+from pathlib import Path
+
+from langchain_core.vectorstores import InMemoryVectorStore
+from langchain_google_genai import GoogleGenerativeAIEmbeddings
+from langchain_community.document_loaders import DirectoryLoader, UnstructuredFileLoader
+from langchain_text_splitters import RecursiveCharacterTextSplitter
+
+load_dotenv()
+
+BASE = Path(__file__).parent.parent.parent
+DATA = BASE / "data"
+
+
+def index_text_docs(
+    data_pth: Path = DATA,
+):
+    embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001")
+    vector_store = InMemoryVectorStore(embeddings)
+
+    # load every supported document type under data_pth
+    loader = DirectoryLoader(
+        str(data_pth),
+        glob=[
+            "**/*.pdf",
+            "**/*.docx",
+            "**/*.pptx",
+            "**/*.md",
+            "**/*.html",
+            "**/*.txt",
+        ],
+        loader_cls=UnstructuredFileLoader,
+    )
+    print(f"Loading docs from {data_pth}")
+    docs = loader.load()
+    print(f"Loaded {len(docs)} docs")
+
+    # split the texts into overlapping chunks
+    text_splitter = RecursiveCharacterTextSplitter(
+        chunk_size=1000,
+        chunk_overlap=200,
+        add_start_index=True,
+        separators=["\n\n", "\n", " ", ""],
+    )
+    all_splits = text_splitter.split_documents(docs)
+    print(f"Split into {len(all_splits)} chunks")
+
+    # embed and index the chunks
+    ids = vector_store.add_documents(documents=all_splits)
+    print(f"Indexed {len(ids)} chunks")
+
+    return embeddings, vector_store
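
For reference, a usage sketch (not in the commit) of building the index and querying it directly; the query string is only an example:

from any_chatbot.indexing import index_text_docs

embeddings, vector_store = index_text_docs()  # defaults to <repo>/data
hits = vector_store.similarity_search("Nike revenue in 2023", k=3)
for doc in hits:
    print(doc.metadata.get("source"), "->", doc.page_content[:80])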

src/any_chatbot/tools.py

Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
+from typing import Tuple, List
+
+from langchain_core.tools import tool
+from langchain.vectorstores.base import VectorStore
+from langchain.schema import Document
+
+
+def initialize_retrieve_tool(vector_store: VectorStore):
+    @tool(
+        description="Retrieve information related to a query",
+        response_format="content_and_artifact",
+    )
+    def retrieve(
+        query: str,
+    ) -> Tuple[str, List[Document]]:
+        # top-3 nearest chunks for the query
+        retrieved_docs = vector_store.similarity_search(query, k=3)
+        serialized = "\n\n".join(
+            f"Source: {doc.metadata}\nContent: {doc.page_content}"
+            for doc in retrieved_docs
+        )
+        return serialized, retrieved_docs
+
+    return retrieve
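
A usage sketch (not in the commit): invoked directly with plain arguments, a content_and_artifact tool returns only the serialized content string; the Document artifacts are attached only when it is called with a ToolCall, as the agent does:

retrieve = initialize_retrieve_tool(vector_store)  # vector_store from index_text_docs
print(retrieve.invoke({"query": "How many employees does Nike have?"}))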
