|
| 1 | +#!/usr/bin/env python3 |
| 2 | +# SPDX-FileCopyrightText: 2023-present Oori Data <info@oori.dev> |
| 3 | +# SPDX-License-Identifier: Apache-2.0 |
| 4 | +# demo/pg-hybrid/chat_doc_folder.py |
| 5 | +''' |
| 6 | +"Chat my docs" demo using PG hybrid search. Skill level: intermediate |
| 7 | +
|
Indexes a folder full of Word, PDF & Markdown documents, then queries an LLM using these as context.
| 9 | +
|
| 10 | +Vector store: PostgreSQL with pgvector - https://github.com/pgvector/pgvector |
| 11 | + Uses OgbujiPT's hybrid search combining dense vector search with sparse BM25 retrieval |
| 12 | +Text to vector (embedding) model: |
| 13 | + Alternatives: https://www.sbert.net/docs/pretrained_models.html / OpenAI ada002 |
| 14 | +PDF to text [PyPDF2](https://pypdf2.readthedocs.io/en/3.x/) |
| 15 | + Alternative: [Docling](https://github.com/DS4SD/docling) |
| 16 | +
|
| 17 | +Needs access to an OpenAI-like service. This can be private/self-hosted, though. |
| 18 | +OgbujiPT's sister project Toolio would work - https://github.com/OoriData/Toolio |
| 19 | +via e.g. llama-cpp-python, text-generation-webui, Ollama |
| 20 | +
|
| 21 | +Prerequisites, in addition to OgbujiPT (or you can just use the `mega` package): |
| 22 | +
|
| 23 | +```sh |
| 24 | +uv pip install fire sentence-transformers docx2python PyPDF2 PyCryptodome # or uv pip install -U ".[mega]" |
| 25 | +``` |
| 26 | +
|
| 27 | +PostgreSQL with pgvector must be running. See README.md in this directory for setup. |
| 28 | +
|
| 29 | +Assume for the following the LLM server is running on localhost, port 8000. |
| 30 | +
|
| 31 | +```sh |
python chat_doc_folder.py --docs=../sample-docs --apibase=http://localhost:8000
| 33 | +``` |
| 34 | +
|
| 35 | +Sample query: "Tell me about the Calabar Kingdom" |
| 36 | +
|
| 37 | +You can always check the retrieval using --verbose |
| 38 | +
|
| 39 | +You can specify your document directory, and/or tweak it with the following command line options: |
| 40 | +--verbose - print more information while processing (for debugging) |
| 41 | +--limit (max number of chunks to retrieve for use as context) |
| 42 | +--chunk-size (characters per chunk, while prepping to create embeddings) |
| 43 | +--chunk-overlap (character overlap between chunks, while prepping to create embeddings) |
| 44 | +--question (The user question; if None (the default), prompt the user interactively) |
| 45 | +''' |
| 46 | +import os |
| 47 | +import asyncio |
| 48 | +from pathlib import Path |
| 49 | + |
| 50 | +import fire |
| 51 | +from docx2python import docx2python |
| 52 | +from PyPDF2 import PdfReader |
| 53 | +from sentence_transformers import SentenceTransformer |
| 54 | + |
| 55 | +from ogbujipt.llm.wrapper import openai_chat_api, prompt_to_chat |
| 56 | +from ogbujipt.text.splitter import text_split_fuzzy |
| 57 | +from ogbujipt.store.postgres import DataDB |
| 58 | +from ogbujipt.retrieval import BM25Search, HybridSearch, SimpleDenseSearch |
| 59 | + |
# Avoid re-entrance complaints from huggingface/tokenizers
| 61 | +os.environ['TOKENIZERS_PARALLELISM'] = 'false' |
| 62 | + |
# Prompt shown when interactively asking the user for a question
USER_PROMPT = 'What do you want to know from the documents?\n'

# Database connection parameters (can be overridden via environment variables)
PG_DB_NAME = os.environ.get('PG_DB_NAME', 'hybrid_demo')
PG_DB_HOST = os.environ.get('PG_DB_HOST', 'localhost')
PG_DB_PORT = int(os.environ.get('PG_DB_PORT', '5432'))
PG_DB_USER = os.environ.get('PG_DB_USER', 'demo_user')
PG_DB_PASSWORD = os.environ.get('PG_DB_PASSWORD', 'demo_pass_2025')

# Default sentence-transformers embedding model name (overridable via --embedding-model)
DEFAULT_EMBEDDING_MODEL = 'all-MiniLM-L6-v2'
| 74 | + |
| 75 | + |
class VectorStore:
    '''Encapsulates PostgreSQL DataDB and hybrid search with chunking parameters'''
    def __init__(self, chunk_size, chunk_overlap, embedding_model, table_name='chat_doc_folder'):
        '''
        Args:
            chunk_size: characters per chunk when splitting document text
            chunk_overlap: characters of overlap between consecutive chunks
            embedding_model: embedding model object, passed through to DataDB
            table_name: PG table used to store the chunks
        '''
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.embedding_model = embedding_model
        self.table_name = table_name
        # Both set up in initialize(); None until then so premature use fails loudly
        self.kb_db = None
        self.hybrid_search = None

    async def initialize(self):
        '''Connect to PostgreSQL, (re)create the demo table & set up hybrid search'''
        # Connect to PostgreSQL
        self.kb_db = await DataDB.from_conn_params(
            embedding_model=self.embedding_model,
            table_name=self.table_name,
            db_name=PG_DB_NAME,
            host=PG_DB_HOST,
            port=PG_DB_PORT,
            user=PG_DB_USER,
            password=PG_DB_PASSWORD,
            itypes=['vector'],  # Create HNSW index for fast vector search
            ifuncs=['cosine']
        )

        # Drop existing table if present (for clean demo), then create fresh
        if await self.kb_db.table_exists():
            await self.kb_db.drop_table()
        await self.kb_db.create_table()

        # Hybrid search: fuses dense vector & sparse BM25 rankings via RRF
        self.hybrid_search = HybridSearch(
            strategies=[
                SimpleDenseSearch(),  # Dense vector search
                BM25Search(k1=1.5, b=0.75, epsilon=0.25)  # Sparse BM25 search
            ],
            k=60  # RRF constant
        )

    def text_split(self, text):
        '''Split text into overlapping chunks using the configured sizes (returns a generator)'''
        return text_split_fuzzy(text, chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap, separator='\n')

    async def update(self, chunks, metas):
        '''Insert (chunk, metadata) pairs into the database'''
        # zip() already yields the (chunk, meta) tuples; no comprehension needed
        await self.kb_db.insert_many(list(zip(chunks, metas)))

    async def search(self, q, limit=None):
        '''Search using hybrid search and return a list of matching content strings

        Args:
            q: query text
            limit: max results to return (defaults to 4 when None/falsy)
        '''
        return [result.content async for result in self.hybrid_search.execute(
            query=q,
            backends=[self.kb_db],
            limit=limit or 4
        )]
| 135 | + |
| 136 | + |
async def read_word_doc(fpath, store):
    '''Extract the text of one Word document, chunk it & insert into the vector store'''
    print('Processing as Word doc:', fpath)  # e.g. 'path/to/file.docx'
    # docx2python is a context manager; text is extracted before it closes
    with docx2python(fpath) as extracted:
        full_text = extracted.text
    pieces = list(store.text_split(full_text))
    source_meta = {'source': str(fpath)}
    await store.update(pieces, metas=[source_meta] * len(pieces))
| 145 | + |
| 146 | + |
async def read_pdf_doc(fpath, store):
    '''Convert a single PDF to text, split into chunks & add these to vector store

    Args:
        fpath: path to the PDF file
        store: VectorStore instance receiving the chunks
    '''
    print('Processing as PDF:', fpath)  # e.g. 'path/to/file.pdf'
    pdf_reader = PdfReader(fpath)
    # extract_text() may yield None/empty for pages with no extractable text
    # (e.g. scanned images); coerce to '' rather than crashing the join
    doctext = ''.join(page.extract_text() or '' for page in pdf_reader.pages)
    chunks = list(store.text_split(doctext))
    metas = [{'source': str(fpath)}]*len(chunks)
    await store.update(chunks, metas=metas)
| 155 | + |
| 156 | + |
async def read_text_or_markdown_doc(fpath, store):
    '''Split a single text or markdown file into chunks & add these to vector store

    Args:
        fpath: path to the .txt/.md/.mdx file
        store: VectorStore instance receiving the chunks
    '''
    print('Processing as text:', fpath)  # e.g. 'path/to/file.txt'
    # Explicit encoding: don't depend on the platform default (e.g. cp1252 on Windows)
    with open(fpath, encoding='utf-8') as doc_file:
        doctext = doc_file.read()
    chunks = list(store.text_split(doctext))
    metas = [{'source': str(fpath)}]*len(chunks)
    await store.update(chunks, metas=metas)
| 165 | + |
| 166 | + |
async def async_main(oapi, docs, verbose, limit, chunk_size, chunk_overlap, question, embedding_model):
    '''
    Index all documents in the docs folder, then run the retrieval-augmented chat loop

    Args:
        oapi: async OpenAI-style chat API callable
        docs: Path to the folder of documents to index
        verbose: print retrieval & LLM details (for debugging)
        limit: max chunks to retrieve as context per question
        chunk_size: characters per chunk
        chunk_overlap: character overlap between chunks
        question: fixed question to answer once, or None to prompt interactively
        embedding_model: loaded embedding model instance
    '''
    store = VectorStore(chunk_size, chunk_overlap, embedding_model)
    await store.initialize()

    # Process all documents, dispatching on file extension
    for fname in docs.iterdir():
        if fname.suffix in ['.doc', '.docx']:
            await read_word_doc(fname, store)
        elif fname.suffix == '.pdf':
            await read_pdf_doc(fname, store)
        elif fname.suffix in ['.txt', '.md', '.mdx']:
            await read_text_or_markdown_doc(fname, store)

    # Main chat loop; runs exactly once when a fixed --question was given
    while True:
        print('\n')
        if question:
            user_question = question
        else:
            user_question = input(USER_PROMPT)
            if user_question.strip() == 'done':
                break

        # Note: don't reuse the `docs` param name for the retrieved chunks
        retrieved = await store.search(user_question, limit=limit)
        if verbose:
            print(retrieved)
        if retrieved:
            gathered_chunks = '\n\n'.join(retrieved)
            # Build system message with the retrieved chunks as provided context
            # In practice we'd use word loom to load the prompts, as demoed in multiprocess.py
            sys_prompt = '''\
You are a helpful assistant, who answers questions directly and as briefly as possible.
Consider the following context and answer the user\'s question.
If you cannot answer with the given context, just say so.\n\n'''
            sys_prompt += gathered_chunks + '\n\n'
            messages = prompt_to_chat(user_question, system=sys_prompt)
            if verbose:
                print('-'*80, '\n', messages, '\n', '-'*80)

            model_params = dict(
                max_tokens=1024,  # Limit number of generated tokens
                top_p=1,  # AKA nucleus sampling; can increase generated text diversity
                frequency_penalty=0,  # Favor more or less frequent tokens
                presence_penalty=1,  # Prefer new, previously unused tokens
                temperature=0.1)

            retval = await oapi(messages, **model_params)
            if verbose:
                print(type(retval))
                print('\nFull response data from LLM:\n', retval)

            # just get back the text of the response
            print('\nResponse text from LLM:\n\n', retval.first_choice_text)

        if question:
            # A fixed question was supplied on the command line: answer once & exit.
            # (Previously this looped forever, re-asking the LLM the same question.)
            break
| 222 | + |
| 223 | + |
def main(
    docs,
    verbose=False,
    chunk_size=200,
    chunk_overlap=20,
    limit=4,
    openai_key=None,
    apibase='http://127.0.0.1:8000',
    model='',
    question=None,
    embedding_model=DEFAULT_EMBEDDING_MODEL
):
    '''
    Chat with documents using PG hybrid search.

    Args:
        docs: directory holding the documents to index (Word, PDF, Markdown, Text)
        verbose: emit extra diagnostic output while processing
        chunk_size: characters per chunk when splitting documents
        chunk_overlap: characters of overlap at chunk edges
        limit: max matched chunks supplied to the LLM as context
        openai_key: OpenAI API key; omit to target a self-hosted server via --apibase
        apibase: base URL of the OpenAI-compatible API (default: http://127.0.0.1:8000)
        model: model name (see https://platform.openai.com/docs/models); pair with --openai-key
        question: question to ask once; if None, prompt the user interactively
        embedding_model: sentence-transformers model name (default: all-MiniLM-L6-v2)
    '''
    # Validate the docs directory before doing any heavy lifting
    doc_dir = Path(docs)
    if not (doc_dir.exists() and doc_dir.is_dir()):
        raise ValueError(f'Document directory does not exist: {docs}')

    # Load the embedding model up front (this is the slow step)
    print(f'\n📦 Loading embedding model: {embedding_model}…')
    st_model = SentenceTransformer(embedding_model)
    print(' ✓ Model loaded!')

    # Hosted OpenAI when a key is given; otherwise an OpenAI-compatible endpoint
    if openai_key:
        oapi = openai_chat_api(api_key=openai_key, model=(model or 'gpt-3.5-turbo'))
    else:
        oapi = openai_chat_api(model=model, base_url=apibase)

    asyncio.run(async_main(oapi, doc_dir, verbose, limit, chunk_size, chunk_overlap, question, st_model))
| 267 | + |
| 268 | + |
if __name__ == '__main__':
    # Fire exposes main()'s keyword args as CLI options (e.g. --docs, --verbose)
    fire.Fire(main)
| 271 | + |
0 commit comments