6 changes: 5 additions & 1 deletion config/config.yaml
@@ -7,9 +7,13 @@ rrf_k: 60
 max_gen_tokens: 400
 seg_filter: null
 chunk_mode : "sections"
-model_path: "models/qwen2.5-1.5b-instruct-q5_k_m.gguf"
+model_path: "qwen2.5-1.5b-instruct-q5_k_m.gguf"
 recursive_chunk_size: 2000
 recursive_overlap: 200
 use_hyde: false
 hyde_max_tokens: 600
+hallucination_detection:
+  enabled: true
+  model_path: "KRLabsOrg/lettucedect-base-modernbert-en-v1"
+  threshold: 0.1
 use_indexed_chunks: false
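
For reference, the new nested block parses as a plain mapping; a minimal PyYAML read-back sketch (illustrative only, not part of this PR):

import yaml

with open("config/config.yaml") as fh:
    cfg = yaml.safe_load(fh)

hd = cfg.get("hallucination_detection", {})
print(hd.get("enabled"), hd.get("threshold"))  # True 0.1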
13 changes: 13 additions & 0 deletions src/config.py
@@ -31,6 +31,11 @@ class QueryPlanConfig:
 
     model_path: os.PathLike
 
+    # hallucination detection
+    hallucination_enabled: bool
+    hallucination_model_path: str
+    hallucination_threshold: float
+
     # testing
     system_prompt_mode: str
     disable_chunks: bool
@@ -80,6 +85,11 @@ def pick(key, default=None):
         seg_filter = pick("seg_filter", None),
         model_path = pick("model_path", None),
 
+        # Hallucination Detection
+        hallucination_enabled = pick("hallucination_detection", {}).get("enabled", False),
+        hallucination_model_path = pick("hallucination_detection", {}).get("model_path", "KRLabsOrg/lettucedect-base-modernbert-en-v1"),
+        hallucination_threshold = pick("hallucination_detection", {}).get("threshold", 0.1),
+
         # Testing
         system_prompt_mode = pick("system_prompt_mode", "baseline"),
         disable_chunks = pick("disable_chunks", False),
@@ -129,6 +139,9 @@ def to_dict(self) -> Dict[str, Any]:
             "rerank_mode": self.rerank_mode,
             "max_gen_tokens": self.max_gen_tokens,
             "model_path": self.model_path,
+            "hallucination_enabled": self.hallucination_enabled,
+            "hallucination_model_path": self.hallucination_model_path,
+            "hallucination_threshold": self.hallucination_threshold,
             "system_prompt_mode": self.system_prompt_mode,
             "disable_chunks": self.disable_chunks,
             "use_golden_chunks": self.use_golden_chunks,
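
Note the fallback chain: if config.yaml omits the hallucination_detection block entirely, pick() returns the {} default and each .get() falls through to its own default, so detection stays off. A minimal sketch of that behavior, with a plain dict standing in for the parsed YAML:

raw = {}  # a config.yaml with no hallucination_detection block
enabled = raw.get("hallucination_detection", {}).get("enabled", False)
assert enabled is False  # detection is disabled unless explicitly enabled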
4 changes: 2 additions & 2 deletions src/embedder.py
@@ -4,13 +4,13 @@
 from tqdm import tqdm
 
 class SentenceTransformer:
-    def __init__(self, model_path: str, n_ctx: int = 40960, n_threads: int = None):
+    def __init__(self, model_path: str, n_ctx: int = 32768, n_threads: int = None):
         """
         Initialize with a local GGUF model file path.
         Args:
             model_path: Path to your local .gguf file
-            n_ctx: Context window size (increased to match Qwen3 training context)
+            n_ctx: Context window size (default 32768 to match Qwen3 training context)
             n_threads: Number of threads to use (None = auto-detect)
         """
         print(f"Loading model with n_ctx={n_ctx}, n_threads={n_threads}")
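
A minimal usage sketch of the changed signature (the .gguf path is a placeholder; only the constructor touched by this diff is exercised):

from src.embedder import SentenceTransformer

# n_ctx now defaults to 32768; override it only if your model was trained
# with a different context window.
embedder = SentenceTransformer("models/your-embedding-model.gguf", n_threads=8)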
99 changes: 99 additions & 0 deletions src/hallucination_detector.py
@@ -0,0 +1,99 @@
"""
Hallucination detection module for TokenSmith using LettuceDetect.
"""

from typing import List, Dict, Any, Optional
from lettucedetect.models.inference import HallucinationDetector


class HallucinationDetectorWrapper:
"""
Wrapper for LettuceDetect hallucination detection.
"""

def __init__(self, model_path: str = "KRLabsOrg/lettucedect-base-modernbert-en-v1", threshold: float = 0.1):
"""
Initialize the hallucination detector.

Args:
model_path: Path to the LettuceDetect model on HuggingFace
threshold: Threshold for considering answer as hallucinated (fraction of unsupported tokens)
"""
self.detector = HallucinationDetector(
method="transformer",
model_path=model_path,
)
self.threshold = threshold

+    def detect_hallucinations(self, question: str, answer: str, contexts: List[str]) -> Dict[str, Any]:
+        """
+        Detect hallucinations in the answer given the question and contexts.
+
+        Args:
+            question: The question asked
+            answer: The generated answer
+            contexts: List of context strings (retrieved chunks)
+
+        Returns:
+            Dict with 'is_hallucinated' (bool), 'unsupported_fraction' (float), and 'hallucinated_spans' (list)
+        """
+        try:
+            # Get predictions from LettuceDetect
+            predictions = self.detector.predict(
+                context=contexts,
+                question=question,
+                answer=answer,
+                output_format="spans"
+            )
+
+            # Calculate the unsupported fraction
+            total_answer_tokens = len(answer.split())  # Rough token count
+            hallucinated_tokens = 0
+
+            hallucinated_spans = []
+            for pred in predictions:
+                span_text = pred['text']
+                span_tokens = len(span_text.split())
+                # span_text is the region LettuceDetect flagged as unsupported
+                hallucinated_tokens += span_tokens
+                hallucinated_spans.append({
+                    'text': span_text,
+                    'confidence': pred['confidence'],
+                    'start': pred['start'],
+                    'end': pred['end']
+                })
+
+            unsupported_fraction = hallucinated_tokens / total_answer_tokens if total_answer_tokens > 0 else 0.0
+            is_hallucinated = unsupported_fraction > self.threshold
+
+            return {
+                'is_hallucinated': is_hallucinated,
+                'unsupported_fraction': unsupported_fraction,
+                'hallucinated_spans': hallucinated_spans
+            }
+
+        except Exception as e:
+            # If detection fails, assume no hallucinations rather than block the answer
+            return {
+                'is_hallucinated': False,
+                'unsupported_fraction': 0.0,
+                'hallucinated_spans': [],
+                'error': str(e)
+            }


+def create_detector(model_path: Optional[str] = None, threshold: float = 0.1) -> HallucinationDetectorWrapper:
+    """
+    Factory function to create a hallucination detector.
+
+    Args:
+        model_path: Path to the model; defaults to the base English model
+        threshold: Detection threshold
+
+    Returns:
+        Configured HallucinationDetectorWrapper
+    """
+    if model_path is None:
+        model_path = "KRLabsOrg/lettucedect-base-modernbert-en-v1"
+
+    return HallucinationDetectorWrapper(model_path=model_path, threshold=threshold)
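
The new module is self-contained; a minimal usage sketch (the strings are illustrative, and the model weights are downloaded from HuggingFace on first use):

from src.hallucination_detector import create_detector

detector = create_detector(threshold=0.1)
result = detector.detect_hallucinations(
    question="What is the capital of France?",
    answer="Paris, which has a population of 90 million.",
    contexts=["Paris is the capital of France."],
)
if result["is_hallucinated"]:
    # here the population claim is likely flagged as unsupported
    print(f"{result['unsupported_fraction']:.1%} of answer tokens look unsupported")
    for span in result["hallucinated_spans"]:
        print(span["text"], span["confidence"])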
15 changes: 15 additions & 0 deletions src/main.py
@@ -11,6 +11,7 @@
 from src.ranking.ranker import EnsembleRanker
 from src.preprocessing.chunking import DocumentChunker
 from src.retriever import apply_seg_filter, BM25Retriever, FAISSRetriever, load_artifacts
+from src.hallucination_detector import create_detector
 from src.query_enhancement import generate_hypothetical_document
 
 
@@ -225,6 +226,20 @@ def get_answer(
         system_prompt_mode=system_prompt
     )
 
+    # Step 5: Hallucination Detection (if enabled)
+    if cfg.hallucination_enabled and ranked_chunks:
+        detector = create_detector(
+            model_path=cfg.hallucination_model_path,
+            threshold=cfg.hallucination_threshold
+        )
+        context_texts = ranked_chunks  # chunks are already strings
+        hallucination_result = detector.detect_hallucinations(question, ans, context_texts)
+
+        if hallucination_result['is_hallucinated']:
+            # Add warning to the answer
+            warning = f"\n\n<!!!> WARNING: This answer may contain hallucinations. {hallucination_result['unsupported_fraction']:.1%} of the content appears unsupported by the provided context."
+            ans += warning
+
     if is_test_mode:
         return ans, chunks_info, hyde_query
     return ans
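
For example, with the default threshold of 0.1, a 40-token answer in which LettuceDetect flags one 6-token span gives unsupported_fraction = 6/40 = 0.15, so the warning is appended. Note that a fresh detector (and model load) is created on every call to get_answer when detection is enabled.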
3 changes: 1 addition & 2 deletions src/preprocessing/chunking.py
@@ -3,8 +3,7 @@
 from dataclasses import dataclass
 from typing import List, Tuple, Optional
 
-from langchain.text_splitter import RecursiveCharacterTextSplitter
-
+from langchain_text_splitters import RecursiveCharacterTextSplitter
 
 # ------------------------ Section Guessing (metadata) -------------------
 
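
RecursiveCharacterTextSplitter now comes from the standalone langchain-text-splitters package rather than the langchain meta-package. A minimal sketch pairing it with the recursive_chunk_size / recursive_overlap values from config.yaml above (how DocumentChunker wires these in is outside this diff):

from langchain_text_splitters import RecursiveCharacterTextSplitter

splitter = RecursiveCharacterTextSplitter(chunk_size=2000, chunk_overlap=200)
text = "TokenSmith splits long documents into overlapping chunks. " * 100
chunks = splitter.split_text(text)
print(len(chunks), len(chunks[0]))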