Skip to content

Commit cdf6cd5

Browse files
authored
Fixed GraphRAG macOS errors and very long load times when parsing documents (#876)
1 parent bbffacf commit cdf6cd5

File tree

3 files changed

+144
-195
lines changed

3 files changed

+144
-195
lines changed

recipes/natural_language_processing/graph-rag/app/Containerfile

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,7 @@ RUN chown -R 1001:0 /graph-rag
99
COPY requirements.txt .
1010
COPY rag_app.py .
1111

12-
# Detect architecture and install Rust only on ARM (aarch64/arm64)
13-
RUN ARCH=$(uname -m) && \
14-
if [ "$ARCH" = "aarch64" ] || [ "$ARCH" = "arm64" ]; then \
15-
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y && \
16-
source "$HOME/.cargo/env" && \
17-
rustc --version && \
18-
cargo --version; \
19-
fi && \
20-
pip install --upgrade pip && \
12+
RUN pip install --upgrade pip && \
2113
pip install --no-cache-dir --upgrade -r /graph-rag/requirements.txt
2214

2315
# Expose the port for the application
Lines changed: 136 additions & 149 deletions
Original file line numberDiff line numberDiff line change
@@ -1,196 +1,183 @@
1-
import sys
21
import os
3-
import shutil
4-
import nest_asyncio
2+
import fitz # PyMuPDF
53
import streamlit as st
6-
import fitz
7-
import logging
8-
9-
# logging.basicConfig(level=logging.DEBUG)
10-
11-
from lightrag import LightRAG, QueryParam
12-
from lightrag.llm.hf import hf_embed
13-
from lightrag.llm.openai import openai_complete_if_cache
14-
from lightrag.utils import EmbeddingFunc, encode_string_by_tiktoken, truncate_list_by_token_size, decode_tokens_by_tiktoken
15-
from transformers import AutoModel, AutoTokenizer
16-
17-
# Apply nest_asyncio to solve event loop issues
18-
nest_asyncio.apply()
19-
20-
WORKING_DIR = "rag_data"
21-
EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
22-
LLM_MODEL = "dummy"
23-
API_KEY = "dummy"
24-
4+
from typing import List
5+
from langchain_core.documents import Document
6+
from langchain_core.vectorstores import InMemoryVectorStore
7+
from langchain_huggingface import HuggingFaceEmbeddings
8+
from langchain_graph_retriever import GraphRetriever
9+
from graph_retriever.strategies import Eager
10+
from langchain_core.prompts import ChatPromptTemplate
11+
from langchain_core.runnables import RunnablePassthrough
12+
from langchain_core.output_parsers import StrOutputParser
13+
from langchain_openai import ChatOpenAI
14+
15+
# Configuration
16+
EMBEDDING_MODEL = "sentence-transformers/all-mpnet-base-v2"
2517
model_service = os.getenv("MODEL_ENDPOINT",
2618
"http://localhost:8001")
2719
model_service = f"{model_service}/v1"
28-
29-
# Check if folder exists
30-
if not os.path.exists(WORKING_DIR):
31-
os.mkdir(WORKING_DIR)
32-
33-
async def llm_model_func(
34-
prompt: str, system_prompt: str = None, history_messages: list[str] = [], **kwargs
35-
) -> str:
36-
"""LLM function to ensure total tokens (prompt + system_prompt + history_messages) <= 2048."""
37-
# Calculate token sizes
38-
prompt_tokens = len(encode_string_by_tiktoken(prompt))
39-
40-
# Calculate remaining tokens for history_messages
41-
max_total_tokens = 1000
42-
43-
# If the prompt itself exceeds the token limit, truncate it
44-
if prompt_tokens > max_total_tokens:
45-
print("Warning: Prompt exceeds token limit. Truncating prompt.")
46-
truncated_prompt = encode_string_by_tiktoken(prompt)[:max_total_tokens]
47-
prompt = decode_tokens_by_tiktoken(truncated_prompt)
48-
prompt_tokens = len(truncated_prompt)
49-
50-
# Truncate history_messages to fit within the remaining tokens
51-
52-
# Log token sizes for debugging
53-
print(f"Prompt tokens: {prompt_tokens}")
54-
55-
# Call the LLM with truncated prompt and history_messages
56-
return await openai_complete_if_cache(
57-
model=LLM_MODEL,
58-
prompt=prompt,
59-
system_prompt=system_prompt,
60-
# history_messages=history_messages,
61-
base_url=model_service,
62-
api_key=API_KEY,
63-
**kwargs,
64-
)
65-
66-
rag = LightRAG(
67-
working_dir=WORKING_DIR,
68-
llm_model_func=llm_model_func,
69-
chunk_token_size = 256,
70-
chunk_overlap_token_size = 50,
71-
llm_model_max_token_size=1000,
72-
llm_model_name=LLM_MODEL,
73-
embedding_func=EmbeddingFunc(
74-
embedding_dim=384,
75-
max_token_size=5000,
76-
func=lambda texts: hf_embed(
77-
texts,
78-
tokenizer=AutoTokenizer.from_pretrained(EMBEDDING_MODEL),
79-
embed_model=AutoModel.from_pretrained(EMBEDDING_MODEL),
80-
),
81-
),
82-
)
20+
LLM_MODEL = "local-model"
21+
WORKING_DIR = "graph_rag_data"
8322

8423
# Initialize session state
8524
if 'uploaded_file_previous' not in st.session_state:
8625
st.session_state.uploaded_file_previous = None
8726

88-
if 'rag_initialized' not in st.session_state:
89-
st.session_state.rag_initialized = False
27+
if 'retriever' not in st.session_state:
28+
st.session_state.retriever = None
29+
30+
if 'chain' not in st.session_state:
31+
st.session_state.chain = None
9032

9133
if 'user_query' not in st.session_state:
9234
st.session_state.user_query = ''
93-
if 'last_submission' not in st.session_state:
94-
st.session_state.last_submission = ''
9535

96-
def pdf_to_text(pdf_path, output_path):
36+
def pdf_to_text(pdf_path: str) -> str:
37+
"""Extract text from PDF file."""
9738
try:
9839
doc = fitz.open(pdf_path)
9940
text = ''
10041
for page in doc:
10142
text += page.get_text()
102-
with open(output_path, 'w', encoding='utf-8') as file:
103-
file.write(text)
43+
return text
10444
except Exception as e:
10545
st.error(f"Error extracting text from PDF: {e}")
10646
raise
10747

108-
async def async_query(query, mode="mix"):
109-
print('\n')
110-
print("query: ", query)
111-
try:
112-
with st.spinner("Processing your query..."):
113-
stream = rag.query(query, param=QueryParam(mode=mode, stream=True, max_token_for_text_unit=1750, max_token_for_global_context=1750, max_token_for_local_context=1750))
48+
def create_documents_from_text(text: str) -> List[Document]:
49+
"""Create LangChain Documents from text with basic metadata."""
50+
chunks = text.split('\n\n') # Simple paragraph-based chunking
51+
documents = []
52+
for i, chunk in enumerate(chunks):
53+
if chunk.strip(): # Skip empty chunks
54+
documents.append(
55+
Document(
56+
page_content=chunk.strip(),
57+
metadata={"id": f"chunk_{i}", "source": "uploaded_file"}
58+
)
59+
)
60+
return documents
11461

115-
# Create a placeholder for the streamed content
116-
output_placeholder = st.empty()
117-
118-
# Manually consume the stream and write to Streamlit
119-
response = ""
62+
def setup_retriever(documents: List[Document]) -> GraphRetriever:
63+
"""Set up the Graph Retriever with HuggingFace embeddings."""
64+
# Initialize embeddings
65+
embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL)
66+
67+
# Create vector store
68+
vector_store = InMemoryVectorStore.from_documents(
69+
documents=documents,
70+
embedding=embeddings,
71+
)
72+
73+
# Create graph retriever
74+
retriever = GraphRetriever(
75+
store=vector_store,
76+
edges=[("source", "source")], # Simple edge - can customize based on your metadata
77+
strategy=Eager(k=5, start_k=1, max_depth=2),
78+
)
79+
80+
return retriever
81+
82+
def setup_llm_chain(retriever: GraphRetriever):
83+
"""Set up the LLM chain with the retriever."""
84+
llm = ChatOpenAI(
85+
base_url=model_service,
86+
api_key="dummy",
87+
model=LLM_MODEL,
88+
streaming=True,
89+
)
90+
91+
prompt = ChatPromptTemplate.from_template(
92+
"""Answer the question based only on the context provided.
12093
121-
# Check if stream is an async iterable
122-
if hasattr(stream, "__aiter__"):
123-
print("async")
124-
async for chunk in stream:
125-
response += chunk
126-
# Update the placeholder with the latest response
127-
output_placeholder.markdown(response, unsafe_allow_html=True)
128-
else:
129-
print("not async")
130-
st.write(stream)
131-
response = stream
94+
Context: {context}
13295
133-
# Store the final response in session state
134-
st.session_state.last_submission = response
96+
Question: {question}"""
97+
)
98+
99+
def format_docs(docs):
100+
return "\n\n".join(f"{doc.page_content}" for doc in docs)
101+
102+
chain = (
103+
{"context": retriever | format_docs, "question": RunnablePassthrough()}
104+
| prompt
105+
| llm
106+
| StrOutputParser()
107+
)
108+
109+
return chain
110+
111+
def process_query(query: str):
112+
"""Process user query using the Graph RAG chain."""
113+
if st.session_state.chain is None:
114+
st.error("Please upload and process a PDF file first.")
115+
return
116+
117+
try:
118+
st.subheader("Answer:")
119+
with st.spinner("Processing your query..."):
120+
# Stream output token-by-token
121+
response_placeholder = st.empty()
122+
123+
full_response = ""
124+
for chunk in st.session_state.chain.stream(query):
125+
full_response += chunk
126+
response_placeholder.markdown(full_response + "▌")
127+
128+
response_placeholder.markdown(full_response)
135129

136-
except ValueError as e:
137-
if "exceed context window" in str(e):
138-
st.error(
139-
"The tokens in your query exceed the model's context window. Please try a different query mode or shorten your query."
140-
)
141-
# Optionally, you could reset the query mode or suggest alternatives
142-
st.session_state.query_mode = "mix" # Default to "mix" mode
143-
st.session_state.user_query = '' # Clear the user query
144-
else:
145-
st.error(f"Error processing query: {e}")
146130
except Exception as e:
147131
st.error(f"Error processing query: {e}")
148132

149-
def query(query, mode="mix"):
150-
# Run the async function in the event loop
151-
import asyncio
152-
asyncio.run(async_query(query, mode))
153133

154134
# Streamlit UI
155-
st.title("GraphRAG Chatbot")
135+
st.title("Graph RAG with PDF Upload")
156136

157137
uploaded_file = st.file_uploader("Upload a PDF file", type="pdf")
158138

159139
if uploaded_file is not None:
160140
if uploaded_file.name != st.session_state.uploaded_file_previous:
161141
st.session_state.uploaded_file_previous = uploaded_file.name
162-
if os.path.exists(WORKING_DIR):
163-
shutil.rmtree(WORKING_DIR, ignore_errors=True)
164-
os.makedirs(WORKING_DIR)
165-
166-
with open("temp.pdf", "wb") as f:
142+
143+
# Create working directory if it doesn't exist
144+
if not os.path.exists(WORKING_DIR):
145+
os.makedirs(WORKING_DIR)
146+
147+
# Save uploaded file temporarily
148+
temp_pdf_path = os.path.join(WORKING_DIR, "temp.pdf")
149+
with open(temp_pdf_path, "wb") as f:
167150
f.write(uploaded_file.getbuffer())
168151

169152
try:
170153
with st.spinner("Processing PDF..."):
171-
pdf_to_text("temp.pdf", "document.txt")
172-
with open("document.txt", "r", encoding="utf-8") as f:
173-
rag.insert(f.read())
174-
st.session_state.rag_initialized = True
154+
text = pdf_to_text(temp_pdf_path)
155+
156+
documents = create_documents_from_text(text)
157+
158+
# Set up retriever and chain
159+
st.session_state.retriever = setup_retriever(documents)
160+
st.session_state.chain = setup_llm_chain(st.session_state.retriever)
161+
162+
st.success("PDF processed successfully! You can now ask questions.")
163+
175164
except Exception as e:
176165
st.error(f"Error processing PDF: {e}")
177166
finally:
178-
if os.path.exists("temp.pdf"):
179-
os.remove("temp.pdf")
180-
181-
if st.session_state.rag_initialized:
182-
query_mode = st.radio(
183-
"Select query mode:",
184-
options=["local", "global", "naive", "hybrid", "mix"],
185-
index=3,
186-
key="mode"
187-
)
188-
st.session_state.query_mode = query_mode
167+
# Clean up temporary file
168+
if os.path.exists(temp_pdf_path):
169+
os.remove(temp_pdf_path)
170+
171+
# Query section
172+
if st.session_state.retriever is not None:
173+
st.subheader("Ask a Question")
189174

190-
# Use a unique key for the text input to avoid conflicts
191-
user_query = st.text_input("Enter your query:", key="query_input")
175+
st.text_input(
176+
"Enter your question about the document:",
177+
key="query_input"
178+
)
179+
user_query = st.session_state.query_input
192180

193-
if st.button("Submit"):
194-
if user_query.strip():
195-
st.session_state.user_query = user_query
196-
query(st.session_state.user_query, mode=st.session_state.query_mode)
181+
if user_query.strip() and user_query != st.session_state.user_query:
182+
st.session_state.user_query = user_query
183+
process_query(user_query)
Lines changed: 7 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,37 +1,7 @@
1-
lightrag-hku==1.1.7
2-
numpy==1.26.4
3-
pydantic==2.10.6
4-
python-dotenv==1.0.1
5-
pipmaster==0.4.0
6-
httpx==0.28.1
7-
nest_asyncio==1.6.0
8-
future==1.0.0
9-
setuptools==75.8.2
10-
tenacity==9.0.0
11-
PyMuPDF==1.25.5
12-
streamlit==1.42.0
13-
tiktoken
14-
torch
15-
transformers
16-
matplotlib
17-
scikit-learn
18-
POT==0.9.5
19-
anytree==2.12.1
20-
autograd==1.7.0
21-
beartype==0.18.5
22-
gensim==4.3.3
23-
graspologic==3.4.1
24-
hyppo==0.4.0
25-
llvmlite==0.44.0
26-
numba==0.61.2
27-
patsy==1.0.1
28-
pynndescent==0.5.13
29-
seaborn==0.13.2
30-
smart-open==7.1.0
31-
statsmodels==0.14.4
32-
umap-learn==0.5.7
33-
wrapt==1.17.2
34-
nano-vectordb==0.0.4.3
35-
jiter==0.8.2
36-
distro==1.9.0
37-
openai==1.64.0
1+
streamlit==1.45.1
2+
langchain-graph-retriever==0.8.0
3+
langchain-huggingface==0.2.0
4+
langchain-openai==0.3.17
5+
transformers==4.52.1
6+
torch==2.7.0
7+
PyMuPDF==1.25.5

0 commit comments

Comments
 (0)