Binary file added .DS_Store
Binary file not shown.
Binary file added DBI_Final_Report.pdf
Binary file not shown.
5 changes: 5 additions & 0 deletions Makefile
@@ -68,3 +68,8 @@ run-chat:
@echo "Note: Chat mode requires interactive terminal. If this fails, use:"
@echo " conda activate tokensmith && python -m src.main chat $(ARGS)"
conda run --no-capture-output -n tokensmith --no-capture-output python -m src.main chat $(ARGS)

run-ui:
@echo "Starting TokenSmith Web UI..."
@echo "The UI will open in your browser at http://localhost:8501"
conda run --no-capture-output -n tokensmith streamlit run src/ui.py
13 changes: 13 additions & 0 deletions README.md
@@ -82,6 +82,19 @@ make run-index ARGS="--pdf_range 1-10 --chunk_mode chars --visualize"

### 6) Chat

**Option A: Web UI (Recommended)**
```shell
make run-ui
# or
streamlit run src/ui.py
```
Opens a modern web interface in your browser with:
- Clean chat interface with message history
- Multiple chat sessions (create new chats anytime)
- Citation display with page numbers and sections
- Settings panel for answer style customization

**Option B: Command Line**
```shell
python -m src.main chat
```
25 changes: 24 additions & 1 deletion config/config.yaml
@@ -12,4 +12,27 @@ recursive_chunk_size: 2000
recursive_overlap: 200
use_hyde: false
hyde_max_tokens: 600
use_indexed_chunks: false
use_indexed_chunks: false

# Indexing performance settings
indexing:
  use_parallel_embedding: true
  num_workers: null # Auto-detect if null
  batch_size: 32
  use_incremental: true
  cache_dir: "index/.cache"

# Contextual retrieval settings
contextual_retrieval:
  enabled: true
  expansion_window: 2
  decay_factor: 0.5
  cross_ref_boost: 1.3

# Query planner settings
use_query_planner: true

# Conversation memory settings
conversation:
  enabled: true
  max_history: 5 # 5 turns = 10 messages
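
For orientation, here is a minimal sketch of how the new `indexing` settings might be consumed when building embeddings. The loader and the `embed_chunks` helper below are illustrative and not part of this diff; only the key names and the `encode_parallel` signature come from the changes in this PR.

```python
# Illustrative wiring of config/config.yaml into the embedder (not part of this PR).
import yaml

with open("config/config.yaml") as f:
    cfg = yaml.safe_load(f)

idx_cfg = cfg.get("indexing", {})

def embed_chunks(embedder, chunk_texts):
    """Use parallel embedding when enabled, otherwise fall back to encode()."""
    if idx_cfg.get("use_parallel_embedding", False):
        return embedder.encode_parallel(
            chunk_texts,
            num_workers=idx_cfg.get("num_workers"),  # null in YAML -> None -> auto-detect
            batch_size=idx_cfg.get("batch_size", 32),
        )
    return embedder.encode(chunk_texts, batch_size=idx_cfg.get("batch_size", 32))
```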
3 changes: 3 additions & 0 deletions environment.yml
@@ -25,3 +25,6 @@ dependencies:
  - rank_bm25
  - langchain
  - pytest
  - langchain-text-splitters
  - docling
  - streamlit
Binary file added index/.DS_Store
Binary file not shown.
59 changes: 59 additions & 0 deletions scripts/check_metadata.py
@@ -0,0 +1,59 @@
#!/usr/bin/env python3
"""
Quick script to check if metadata exists in the index.
"""

import pathlib
import pickle
import sys

def check_metadata():
    """Check if metadata file exists and has content."""
    artifacts_dir = pathlib.Path("index/sections")
    index_prefix = "textbook_index"
    metadata_path = artifacts_dir / f"{index_prefix}_meta.pkl"

    if not metadata_path.exists():
        print("❌ Metadata file NOT found!")
        print(f" Path: {metadata_path}")
        print("\n To enable citations, rebuild the index:")
        print(" make run-index")
        return False

    try:
        metadata = pickle.load(open(metadata_path, "rb"))
        if not metadata or len(metadata) == 0:
            print("⚠️ Metadata file exists but is EMPTY!")
            print(f" Path: {metadata_path}")
            print("\n Rebuild the index to populate metadata:")
            print(" make run-index")
            return False

        # Check if metadata has actual content
        sample = metadata[0] if metadata else {}
        has_content = sample.get('page_number') or sample.get('section') or sample.get('chapter')

        if has_content:
            print(f"✅ Metadata file found with {len(metadata)} entries")
            print(f" Path: {metadata_path}")
            print(f" Sample entry keys: {list(sample.keys())}")
            if sample.get('page_number'):
                print(f" Sample page number: {sample.get('page_number')}")
            return True
        else:
            print("⚠️ Metadata file exists but entries are EMPTY!")
            print(f" Path: {metadata_path}")
            print("\n Rebuild the index to populate metadata:")
            print(" make run-index")
            return False

    except Exception as e:
        print(f"❌ Error reading metadata file: {e}")
        return False

if __name__ == "__main__":
    success = check_metadata()
    sys.exit(0 if success else 1)
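
The script above only verifies that entries exist and are non-empty; the fields it (and the citation code in `src/generator.py`) expects in each entry are shown below. The values are invented purely for illustration.

```python
# Hypothetical metadata entry, matching the keys read by check_metadata.py and src/generator.py.
example_entry = {
    "page_number": 412,                                    # rendered as "Page 412"
    "chapter": 14,                                         # 0 is treated as "no chapter"
    "section": "B+-Tree Index Files",                      # human-readable title (made up)
    "section_hierarchy": {"section": 3, "subsection": 1},  # rendered as "Section 3.1"
}
```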



86 changes: 86 additions & 0 deletions src/embedder.py
@@ -2,6 +2,38 @@
from typing import List, Union
from llama_cpp import Llama
from tqdm import tqdm
import multiprocessing
from multiprocessing import Pool, cpu_count
import os


def _encode_worker(texts: List[str], model_path: str, batch_size: int, embedding_dim: int) -> np.ndarray:
"""
Worker function that loads model and encodes texts.
Each worker process gets its own model instance.
"""
# Load model in worker process
model = Llama(
model_path=model_path,
n_ctx=40960,
embedding=True,
verbose=False,
n_batch=512,
use_mmap=True,
logits_all=True
)

embeddings = []
for text in texts:
try:
emb = model.create_embedding(text)['data'][0]['embedding']
embeddings.append(emb)
except Exception as e:
print(f"Error encoding text in worker: {e}")
embeddings.append([0.0] * embedding_dim)

return np.array(embeddings, dtype=np.float32)


class SentenceTransformer:
    def __init__(self, model_path: str, n_ctx: int = 40960, n_threads: int = None):
@@ -15,6 +47,7 @@ def __init__(self, model_path: str, n_ctx: int = 40960, n_threads: int = None):
        """
        print(f"Loading model with n_ctx={n_ctx}, n_threads={n_threads}")

        self._model_path = model_path  # Store for parallel encoding
        self.model = Llama(
            model_path=model_path,
            n_ctx=n_ctx,
@@ -99,6 +132,59 @@ def encode(self,
        vecs = vecs / norms

        return vecs

    def encode_parallel(self,
                        texts: List[str],
                        num_workers: int = None,
                        batch_size: int = 32,
                        show_progress_bar: bool = True) -> np.ndarray:
        """
        Encode texts in parallel using multiprocessing.
        Each worker gets its own model instance.

        Args:
            texts: List of texts to encode
            num_workers: Number of worker processes (None = auto-detect)
            batch_size: Batch size per worker
            show_progress_bar: Whether to show progress bar

        Returns:
            numpy.ndarray: Float32 embeddings array
        """
        if not texts:
            return np.zeros((0, self.embedding_dimension), dtype=np.float32)  # reshape(0, -1) on an empty array raises ValueError

        if num_workers is None:
            num_workers = max(1, min(cpu_count() - 1, len(texts) // 10))
        num_workers = max(1, num_workers)  # At least 1 worker

        # For small batches, use regular encoding
        if len(texts) < 100 or num_workers == 1:
            return self.encode(texts, batch_size=batch_size, show_progress_bar=show_progress_bar)

        # Split texts into worker chunks
        chunk_size = max(1, len(texts) // num_workers)
        text_chunks = [texts[i:i+chunk_size] for i in range(0, len(texts), chunk_size)]

        # Get model path from the model object (we need to store it)
        model_path = getattr(self, '_model_path', None)
        if model_path is None:
            # Fallback to regular encoding if we can't get model path
            print("Warning: Cannot determine model path for parallel encoding. Using sequential encoding.")
            return self.encode(texts, batch_size=batch_size, show_progress_bar=show_progress_bar)

        print(f"Encoding {len(texts)} texts using {num_workers} workers...")

        # Create worker pool
        worker_args = [(chunk, model_path, batch_size, self.embedding_dimension)
                       for chunk in text_chunks]

        with Pool(num_workers) as pool:
            results = pool.starmap(_encode_worker, worker_args)

        # Concatenate results
        all_embeddings = np.vstack(results)
        return all_embeddings

    def embed_one(self, text: str, normalize: bool = False) -> List[float]:
        """Encode single text and return as list."""
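
A hedged usage sketch for the new `encode_parallel` method follows; the model path and texts are placeholders. Note that inputs with fewer than 100 texts (or `num_workers=1`) silently fall back to the sequential `encode` path, so the worker pool only pays off on larger batches.

```python
# Hypothetical usage of encode_parallel (model path is a placeholder).
from src.embedder import SentenceTransformer

embedder = SentenceTransformer("models/embedding-model.gguf")  # placeholder path
texts = [f"chunk {i}" for i in range(500)]

vecs = embedder.encode_parallel(texts, num_workers=4, batch_size=32)
print(vecs.shape)  # (500, embedder.embedding_dimension)
```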
83 changes: 66 additions & 17 deletions src/generator.py
@@ -120,31 +120,86 @@ def get_system_prompt(mode="tutor"):
    return prompts.get(mode)


def format_prompt(chunks, query, max_chunk_chars=400, system_prompt_mode="tutor"):
def format_prompt(chunks, query, max_chunk_chars=400, system_prompt_mode="tutor", chunk_metadata=None, conversation_history=None):
"""
Format prompt for LLM with chunks and query.
Format prompt for LLM with chunks, query, metadata, and conversation history.

Args:
chunks: List of text chunks (can be empty for baseline)
query: User question
max_chunk_chars: Maximum characters per chunk
system_prompt_mode: System prompt mode (baseline, tutor, concise, detailed)
chunk_metadata: List of metadata dicts for each chunk (optional)
conversation_history: List of previous conversation turns (optional)
"""
# Get system prompt
system_prompt = get_system_prompt(system_prompt_mode)
system_section = f"<|im_start|>system\n{system_prompt}\n<|im_end|>\n" if system_prompt else ""

# Add citation instruction if metadata is available
citation_instruction = ""
if chunk_metadata and len(chunk_metadata) > 0:
citation_instruction = "\nIMPORTANT: When answering, cite your sources using [Page X, Chapter Y, Section Z] notation. Include citations for all information you use from the textbook excerpts.\n"

if system_prompt:
system_prompt = system_prompt + citation_instruction
system_section = f"<|im_start|>system\n{system_prompt}\n<|im_end|>\n"
else:
system_section = ""

# Build conversation history if provided
history_section = ""
if conversation_history and len(conversation_history) > 0:
for turn in conversation_history:
role = turn.get("role", "user")
content = turn.get("content", "")
if role == "user":
history_section += f"<|im_start|>user\n{content}\n<|im_end|>\n"
elif role == "assistant":
history_section += f"<|im_start|>assistant\n{content}\n<|im_end|>\n"

# Build prompt based on whether chunks are provided
if chunks and len(chunks) > 0:
trimmed = [(c or "")[:max_chunk_chars] for c in chunks]
context = "\n\n".join(trimmed)
# Build context with detailed citations
context_parts = []
for idx, chunk in enumerate(chunks):
trimmed = (chunk or "")[:max_chunk_chars]

# Add detailed citation if metadata available
if chunk_metadata and idx < len(chunk_metadata):
meta = chunk_metadata[idx]
page_num = meta.get('page_number')
chapter = meta.get('chapter', 0)
section_hierarchy = meta.get('section_hierarchy', {})
section = meta.get('section', 'Unknown')

# Build citation string
citation_parts = []
if page_num:
citation_parts.append(f"Page {page_num}")
if chapter > 0:
citation_parts.append(f"Chapter {chapter}")
if section_hierarchy.get('section', 0) > 0:
section_str = f"{section_hierarchy['section']}"
if section_hierarchy.get('subsection', 0) > 0:
section_str += f".{section_hierarchy['subsection']}"
citation_parts.append(f"Section {section_str}")

if citation_parts:
citation = f"[Source: {', '.join(citation_parts)}]"
context_parts.append(f"{citation}\n{trimmed}")
else:
context_parts.append(trimmed)
else:
context_parts.append(trimmed)

context = "\n\n".join(context_parts)
context = text_cleaning(context)

# Build prompt with chunks
context_section = f"Textbook Excerpts:\n{context}\n\n\n"

return textwrap.dedent(f"""\
{system_section}<|im_start|>user
{system_section}{history_section}<|im_start|>user
{context_section}Question: {query}
<|im_end|>
<|im_start|>assistant
@@ -155,7 +210,7 @@ def format_prompt(chunks, query, max_chunk_chars=400, system_prompt_mode="tutor"
question_label = "Question: " if system_prompt else ""

return textwrap.dedent(f"""\
{system_section}<|im_start|>user
{system_section}{history_section}<|im_start|>user
{question_label}{query}
<|im_end|>
<|im_start|>assistant
@@ -212,17 +267,11 @@ def _dedupe_sentences(text: str) -> str:
        cleaned.append(s)
    return " ".join(cleaned)

def answer(query: str, chunks, model_path: str, max_tokens: int = 300, **kw):
    prompt = format_prompt(chunks, query)
    approx_tokens = max(1, len(prompt) // 4)
    #print(f"\n⚙️ Prompt length ≈ {approx_tokens} tokens\n")
    raw = run_llama_cpp(prompt, model_path, max_tokens=max_tokens, **kw)
    return _dedupe_sentences(raw)

def answer(query: str, chunks, model_path: str, max_tokens: int = 300,
           system_prompt_mode: str = "tutor", **kw):
    prompt = format_prompt(chunks, query, system_prompt_mode=system_prompt_mode)
           system_prompt_mode: str = "tutor", chunk_metadata=None, conversation_history=None, **kw):
    prompt = format_prompt(chunks, query, system_prompt_mode=system_prompt_mode,
                           chunk_metadata=chunk_metadata, conversation_history=conversation_history)
    # approx_tokens = max(1, len(prompt) // 4)
    #print(f"\n⚙️ Prompt length ≈ {approx_tokens} tokens (mode: {system_prompt_mode})\n")
    #print(f"\nPrompt length ≈ {approx_tokens} tokens (mode: {system_prompt_mode})\n")
    raw = run_llama_cpp(prompt, model_path, max_tokens=max_tokens, **kw)
    return _dedupe_sentences(raw)
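
For completeness, a sketch of how the extended `answer` signature might be called with citations and conversation memory. The chunk text, metadata values, and model path are placeholders; only the parameter names and expected dict keys come from the diff above.

```python
# Hypothetical call into the extended answer() API (data and paths are placeholders).
from src.generator import answer

chunks = ["A B+-tree keeps keys sorted and balanced for disk-based lookups..."]
chunk_metadata = [{"page_number": 412, "chapter": 14,
                   "section_hierarchy": {"section": 3, "subsection": 1}}]
conversation_history = [
    {"role": "user", "content": "What is an index?"},
    {"role": "assistant", "content": "An auxiliary structure that speeds up lookups."},
]

reply = answer(
    "How does a B+-tree differ from a hash index?",
    chunks,
    model_path="models/llm.gguf",  # placeholder
    system_prompt_mode="tutor",
    chunk_metadata=chunk_metadata,
    conversation_history=conversation_history,
)
print(reply)
```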