import os
from typing import List

import fitz  # PyMuPDF
import streamlit as st
from langchain_core.documents import Document
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_core.vectorstores import InMemoryVectorStore
from langchain_graph_retriever import GraphRetriever
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_openai import ChatOpenAI
from graph_retriever.strategies import Eager

# Configuration
EMBEDDING_MODEL = "sentence-transformers/all-mpnet-base-v2"

# OpenAI-compatible local model endpoint (e.g. llama.cpp / vLLM);
# override with the MODEL_ENDPOINT environment variable.
model_service = os.getenv("MODEL_ENDPOINT", "http://localhost:8001")
model_service = f"{model_service}/v1"

LLM_MODEL = "local-model"
WORKING_DIR = "graph_rag_data"
# Initialize session state (Streamlit reruns the script on every interaction,
# so defaults must be set only when the key is absent).
if 'uploaded_file_previous' not in st.session_state:
    st.session_state.uploaded_file_previous = None

if 'retriever' not in st.session_state:
    st.session_state.retriever = None

if 'chain' not in st.session_state:
    st.session_state.chain = None

if 'user_query' not in st.session_state:
    st.session_state.user_query = ''
def pdf_to_text(pdf_path: str) -> str:
    """Extract text from a PDF file.

    Args:
        pdf_path: Filesystem path of the PDF to read.

    Returns:
        The text of every page, concatenated in page order.

    Raises:
        Exception: re-raised after showing the error in the Streamlit UI.
    """
    try:
        # Context manager ensures the fitz Document is closed even on error.
        with fitz.open(pdf_path) as doc:
            return ''.join(page.get_text() for page in doc)
    except Exception as e:
        st.error(f"Error extracting text from PDF: {e}")
        raise
def create_documents_from_text(text: str) -> List[Document]:
    """Create LangChain Documents from raw text with basic metadata.

    Uses simple paragraph-based chunking: the text is split on blank lines
    and empty chunks are skipped. Chunk ids keep the original paragraph
    index, so they are not necessarily consecutive.

    Args:
        text: The raw document text.

    Returns:
        A list of Documents, one per non-empty paragraph.
    """
    documents = []
    for i, chunk in enumerate(text.split('\n\n')):
        stripped = chunk.strip()
        if not stripped:  # Skip empty chunks
            continue
        documents.append(
            Document(
                page_content=stripped,
                metadata={"id": f"chunk_{i}", "source": "uploaded_file"},
            )
        )
    return documents
def setup_retriever(documents: List[Document]) -> GraphRetriever:
    """Set up the Graph Retriever with HuggingFace embeddings.

    Args:
        documents: The chunked Documents to index.

    Returns:
        A GraphRetriever over an in-memory vector store.
    """
    # Initialize embeddings
    embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL)

    # Create vector store
    vector_store = InMemoryVectorStore.from_documents(
        documents=documents,
        embedding=embeddings,
    )

    # NOTE(review): every chunk carries source="uploaded_file", so this edge
    # links all chunks to each other — customize the edge definition once
    # richer metadata is available.
    retriever = GraphRetriever(
        store=vector_store,
        edges=[("source", "source")],
        strategy=Eager(k=5, start_k=1, max_depth=2),
    )

    return retriever
def setup_llm_chain(retriever: GraphRetriever):
    """Set up the LLM chain with the retriever.

    Builds: {context, question} -> prompt -> streaming ChatOpenAI -> str.

    Args:
        retriever: The GraphRetriever that supplies context documents.

    Returns:
        A runnable chain that maps a question string to an answer string.
    """
    llm = ChatOpenAI(
        base_url=model_service,
        api_key="dummy",  # local endpoint ignores the key, but the client requires one
        model=LLM_MODEL,
        streaming=True,
    )

    prompt = ChatPromptTemplate.from_template(
        """Answer the question based only on the context provided.

Context: {context}

Question: {question}"""
    )

    def format_docs(docs):
        # Join retrieved chunks into one context string, blank-line separated.
        return "\n\n".join(doc.page_content for doc in docs)

    chain = (
        {"context": retriever | format_docs, "question": RunnablePassthrough()}
        | prompt
        | llm
        | StrOutputParser()
    )

    return chain
def process_query(query: str):
    """Process a user query with the Graph RAG chain, streaming the answer.

    Shows an error and returns early when no chain has been built yet
    (i.e. no PDF was processed).

    Args:
        query: The user's question.
    """
    if st.session_state.chain is None:
        st.error("Please upload and process a PDF file first.")
        return

    try:
        st.subheader("Answer:")
        with st.spinner("Processing your query..."):
            # Stream output token-by-token; "▌" acts as a typing cursor.
            response_placeholder = st.empty()
            full_response = ""
            for chunk in st.session_state.chain.stream(query):
                full_response += chunk
                response_placeholder.markdown(full_response + "▌")

            # Final render without the cursor.
            response_placeholder.markdown(full_response)

    except Exception as e:
        st.error(f"Error processing query: {e}")
# Streamlit UI
st.title("Graph RAG with PDF Upload")

uploaded_file = st.file_uploader("Upload a PDF file", type="pdf")

if uploaded_file is not None:
    # Re-process only when a different file is uploaded.
    if uploaded_file.name != st.session_state.uploaded_file_previous:
        st.session_state.uploaded_file_previous = uploaded_file.name

        # Create working directory if it doesn't exist
        os.makedirs(WORKING_DIR, exist_ok=True)

        # Save uploaded file temporarily
        temp_pdf_path = os.path.join(WORKING_DIR, "temp.pdf")
        with open(temp_pdf_path, "wb") as f:
            f.write(uploaded_file.getbuffer())

        try:
            with st.spinner("Processing PDF..."):
                text = pdf_to_text(temp_pdf_path)
                documents = create_documents_from_text(text)

                # Set up retriever and chain
                st.session_state.retriever = setup_retriever(documents)
                st.session_state.chain = setup_llm_chain(st.session_state.retriever)

            st.success("PDF processed successfully! You can now ask questions.")

        except Exception as e:
            st.error(f"Error processing PDF: {e}")
        finally:
            # Clean up temporary file
            if os.path.exists(temp_pdf_path):
                os.remove(temp_pdf_path)

# Query section
if st.session_state.retriever is not None:
    st.subheader("Ask a Question")

    st.text_input(
        "Enter your question about the document:",
        key="query_input"
    )
    user_query = st.session_state.query_input

    # Run only on a new, non-empty query to avoid reprocessing on reruns.
    if user_query.strip() and user_query != st.session_state.user_query:
        st.session_state.user_query = user_query
        process_query(user_query)