Skip to content

Commit 051c942

Browse files
committed
feat: resolve all documented limitations
1. NLI model fallback for contradiction detection - DeBERTa-based NLI model as fallback when rule-based is uncertain - Fixes weak modality/temporal detection - backend/core/bel/nli_detector.py 2. Response validator for LLM hallucination detection - Extracts claims from LLM responses - Checks claims against stored beliefs - Regenerates response if contradictions found - backend/chat/response_validator.py 3. Zero-shot query classifier for hybrid routing - Replaces regex patterns with BART-MNLI classifier - Better coverage for real-time query detection - Regex fallback when model unavailable - backend/llm/query_classifier.py 26 new tests added. 698 tests passing.
1 parent ad193cb commit 051c942

File tree

11 files changed

+1069
-52
lines changed

11 files changed

+1069
-52
lines changed

backend/chat/response_validator.py

Lines changed: 245 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,245 @@
1+
# Author: Bradley R. Kinnard
2+
"""
3+
Response Validator - validates LLM responses against stored beliefs.
4+
Extracts claims from LLM output and checks for contradictions with the belief store.
5+
"""
6+
7+
import logging
8+
import re
9+
from dataclasses import dataclass, field
10+
from typing import TYPE_CHECKING
11+
12+
if TYPE_CHECKING:
13+
from ..core.models.belief import Belief
14+
15+
logger = logging.getLogger(__name__)
16+
17+
# lazy-loaded
18+
_nlp = None
19+
20+
21+
def _get_nlp():
22+
"""Lazy load spacy for claim extraction."""
23+
global _nlp
24+
if _nlp is not None:
25+
return _nlp
26+
try:
27+
import spacy
28+
_nlp = spacy.load("en_core_web_sm")
29+
return _nlp
30+
except Exception as e:
31+
logger.warning(f"spacy unavailable for claim extraction: {e}")
32+
return None
33+
34+
35+
@dataclass
class ExtractedClaim:
    """A single factual claim pulled out of an LLM response."""

    # The claim text itself (currently identical to ``sentence``).
    text: str
    # The full sentence the claim was extracted from.
    sentence: str
    # How certain the extractor is that this is a checkable factual claim.
    confidence: float = 1.0
    # True when the wording was softened ("might be", "could be", ...).
    is_hedged: bool = False


@dataclass
class ValidationResult:
    """Outcome of checking an LLM response against the belief store."""

    # Flips to False as soon as any claim contradicts a stored belief.
    is_valid: bool = True
    # One dict per detected contradiction (claim, belief id/content, score, codes).
    contradictions: list[dict] = field(default_factory=list)
    # Number of factual claims that were extracted and compared.
    claims_checked: int = 0
    # The claims that triggered contradictions.
    flagged_claims: list[ExtractedClaim] = field(default_factory=list)
51+
52+
53+
# Phrases that soften a statement; hedged claims are extracted with
# reduced confidence and skipped during contradiction checking.
HEDGE_PHRASES = frozenset({
    "might", "may", "could", "possibly", "perhaps", "maybe",
    "i think", "i believe", "it seems", "appears to",
    "not sure", "uncertain", "likely", "probably",
    "in my opinion", "as far as i know", "i'm not certain",
})

# Phrases showing the LLM is quoting the user's own statements back;
# such sentences are trusted and not treated as independent claims.
CITATION_PHRASES = frozenset({
    "you mentioned", "you said", "you told me",
    "according to what you said", "based on what you told me",
    "from our conversation", "as you noted", "you indicated",
})
67+
68+
69+
def extract_claims(response: str) -> list[ExtractedClaim]:
    """
    Extract factual claims from LLM response text.

    Filters out questions, citations of the user's own statements, and
    non-factual content; hedged statements are kept but marked with lower
    confidence. Uses spaCy when available, otherwise a regex sentence
    splitter.

    Fixes over the previous revision:
    - the fallback splitter now keeps each sentence's terminating
      punctuation (the old ``re.split(r'[.!?]+', ...)`` stripped it, so the
      ``endswith('?')`` question filter never fired in fallback mode);
    - hedge/citation phrases are matched on word boundaries, so e.g.
      "unlikely" no longer triggers the hedge "likely".
    """
    # Word-boundary alternations avoid substring false positives; the `re`
    # module caches compilations, so rebuilding per call is cheap. sorted()
    # keeps the pattern deterministic across runs.
    hedge_re = re.compile(r"\b(?:%s)\b" % "|".join(map(re.escape, sorted(HEDGE_PHRASES))))
    cite_re = re.compile(r"\b(?:%s)\b" % "|".join(map(re.escape, sorted(CITATION_PHRASES))))

    claims: list[ExtractedClaim] = []
    nlp = _get_nlp()

    if nlp is None:
        # Fallback: sentence splitting without NLP. The lookbehind split
        # preserves the '.', '!' or '?' so questions can be recognized.
        for sent in re.split(r"(?<=[.!?])\s+", response):
            sent = sent.strip()
            if len(sent) < 10:
                continue
            if sent.endswith('?'):
                continue  # skip questions

            lower_sent = sent.lower()

            # skip if citing user's beliefs
            if cite_re.search(lower_sent):
                continue

            # check for hedging
            is_hedged = bool(hedge_re.search(lower_sent))

            # basic factual claim detection: common assertion verbs or numbers
            is_factual = bool(re.search(r'\b(is|are|was|were|has|have|had|will|can|does|did)\b', lower_sent))
            is_factual = is_factual or bool(re.search(r'\d+', sent))

            if is_factual:
                claims.append(ExtractedClaim(
                    text=sent,
                    sentence=sent,
                    confidence=0.5 if is_hedged else 0.8,
                    is_hedged=is_hedged,
                ))
        return claims

    # NLP-based extraction
    doc = nlp(response)

    for sent in doc.sents:
        sent_text = sent.text.strip()
        if len(sent_text) < 10:
            continue

        # skip questions
        if sent_text.endswith('?'):
            continue

        lower_sent = sent_text.lower()

        # skip if citing user's beliefs
        if cite_re.search(lower_sent):
            continue

        # check for hedging
        is_hedged = bool(hedge_re.search(lower_sent))

        # check if sentence contains factual assertions
        has_verb = any(tok.pos_ == "VERB" for tok in sent)
        has_subj = any(tok.dep_ in ("nsubj", "nsubjpass") for tok in sent)
        has_entity = any(ent.label_ in ("PERSON", "ORG", "GPE", "DATE", "TIME", "MONEY", "QUANTITY", "PERCENT") for ent in sent.ents)
        has_number = any(tok.like_num for tok in sent)

        # factual if it has subject+verb or contains entities/numbers
        is_factual = (has_verb and has_subj) or has_entity or has_number

        if is_factual:
            claims.append(ExtractedClaim(
                text=sent_text,
                sentence=sent_text,
                confidence=0.5 if is_hedged else 0.9,
                is_hedged=is_hedged,
            ))

    return claims
148+
149+
150+
def validate_response(
    response: str,
    beliefs: list["Belief"],
    contradiction_threshold: float = 0.6,
) -> ValidationResult:
    """
    Validate LLM response against stored beliefs.

    Every un-hedged claim extracted from ``response`` is compared with every
    belief; matches at or above ``contradiction_threshold`` are recorded and
    mark the result invalid.

    Args:
        response: LLM response text
        beliefs: List of user beliefs to check against
        contradiction_threshold: Min confidence to flag contradiction

    Returns:
        ValidationResult with any contradictions found
    """
    # Local import keeps module load light and avoids circular imports.
    from backend.core.bel.semantic_contradiction import check_contradiction

    result = ValidationResult()

    if not beliefs:
        return result  # no stored beliefs - nothing to compare against

    claims = extract_claims(response)
    result.claims_checked = len(claims)

    if not claims:
        return result  # response made no checkable factual assertions

    # Hedged claims already acknowledge uncertainty, so they get a pass.
    for claim in (c for c in claims if not c.is_hedged):
        for belief in beliefs:
            verdict = check_contradiction(claim.text, belief.content)
            if verdict.label != "contradiction" or verdict.confidence < contradiction_threshold:
                continue

            result.is_valid = False
            result.contradictions.append({
                "claim": claim.text,
                "belief_id": str(belief.id),
                "belief_content": belief.content,
                "confidence": verdict.confidence,
                "reason_codes": verdict.reason_codes,
            })
            result.flagged_claims.append(claim)
            logger.warning(
                f"LLM claim contradicts belief: '{claim.text[:50]}...' vs '{belief.content[:50]}...'"
            )

    return result
205+
206+
207+
def get_correction_prompt(
    original_response: str,
    contradictions: list[dict],
    beliefs: list["Belief"],
) -> str:
    """
    Build a follow-up prompt asking the LLM to fix a response that
    contradicted stored beliefs.

    Caps the context at 10 beliefs and 5 contradictions, and truncates the
    quoted texts, to keep the prompt small.
    """
    belief_lines = [
        f"- {b.content} (confidence: {b.confidence:.0%})" for b in beliefs[:10]
    ]
    conflict_lines = [
        f"- Your claim '{c['claim'][:60]}...' contradicts: '{c['belief_content'][:60]}...'"
        for c in contradictions[:5]
    ]
    belief_context = "\n".join(belief_lines)
    contradiction_details = "\n".join(conflict_lines)

    return f"""Your previous response contained claims that contradict what the user has told you.

WHAT THE USER HAS TOLD YOU (trust these):
{belief_context}

CONTRADICTIONS FOUND:
{contradiction_details}

Please regenerate your response, ensuring you don't contradict the user's stated facts. If you're uncertain about something, acknowledge that uncertainty rather than stating incorrect facts.

Original response to fix:
{original_response[:500]}..."""
237+
238+
239+
# Public API of this module; keep in sync with the names imported by
# backend/chat/service.py and any package-level re-exports.
__all__ = [
    "ExtractedClaim",
    "ValidationResult",
    "extract_claims",
    "validate_response",
    "get_correction_prompt",
]

backend/chat/service.py

Lines changed: 65 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,62 @@ def get_or_create_session(self, session_id: Optional[UUID] = None) -> ChatSessio
120120
self._sessions[session.id] = session
121121
return session
122122

123+
async def _validate_and_correct_response(
    self,
    response: str,
    beliefs: list,
    llm,
    messages: list,
    max_retries: int = 1,
) -> str:
    """
    Validate LLM response against beliefs and correct if needed.

    Args:
        response: Candidate assistant response text.
        beliefs: Beliefs the response must not contradict.
        llm: Chat client exposing ``await llm.chat(...)``.
        messages: Conversation history used to regenerate the response.
        max_retries: How many correction rounds to attempt.

    Returns:
        The original response if valid, a corrected response if a retry
        passes validation, or the last attempt prefixed with a warning
        note when all retries still contradict stored beliefs.
    """
    from .response_validator import validate_response, get_correction_prompt

    validation = validate_response(response, beliefs)

    if validation.is_valid:
        return response

    logger.warning(
        f"Response validation failed: {len(validation.contradictions)} contradictions found"
    )

    # try to correct
    for attempt in range(max_retries):
        correction_prompt = get_correction_prompt(
            response, validation.contradictions, beliefs
        )

        # append correction request to the original history
        corrected_messages = messages + [
            ChatMessage(role="assistant", content=response),
            ChatMessage(role="user", content=correction_prompt),
        ]

        corrected = await llm.chat(
            messages=corrected_messages,
            beliefs=beliefs,
            temperature=0.3,  # lower temp keeps the correction focused
            max_tokens=settings.llm_max_tokens,
        )

        # validate corrected response
        revalidation = validate_response(corrected.content, beliefs)

        if revalidation.is_valid:
            logger.info("Response corrected successfully")
            return corrected.content

        logger.warning(f"Correction attempt {attempt + 1} still has contradictions")
        # Fix: carry the *fresh* contradictions into the next round.
        # Previously only `response` was updated, so with max_retries > 1
        # later correction prompts reused the first, stale validation.
        validation = revalidation
        response = corrected.content

    # give up - return last attempt with warning prefix
    return f"[Note: Response may contain inaccuracies]\n\n{response}"
178+
123179
async def process_message(
124180
self,
125181
message: str,
@@ -395,7 +451,15 @@ async def process_message(
395451
max_tokens=settings.llm_max_tokens,
396452
)
397453

398-
turn.assistant_message = response.content
454+
# Step 8: Validate response against beliefs (catch hallucinations)
455+
validated_response = await self._validate_and_correct_response(
456+
response.content,
457+
top_beliefs,
458+
llm,
459+
messages,
460+
)
461+
462+
turn.assistant_message = validated_response
399463
turn.beliefs_used = [b.id for b in top_beliefs]
400464
turn.duration_ms = (datetime.now(timezone.utc) - start).total_seconds() * 1000
401465

backend/core/bel/__init__.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,12 @@
1414
Proposition,
1515
RuleBasedContradictionDetector,
1616
)
17+
from .nli_detector import (
18+
check_contradiction_nli,
19+
classify_nli,
20+
is_nli_available,
21+
NLIResult,
22+
)
1723
from .snapshot_compression import compress_snapshot, decompress_snapshot
1824
from .snapshot_logger import log_snapshot
1925

@@ -33,4 +39,8 @@
3339
"ContradictionResult",
3440
"Proposition",
3541
"RuleBasedContradictionDetector",
42+
"check_contradiction_nli",
43+
"classify_nli",
44+
"is_nli_available",
45+
"NLIResult",
3646
]

0 commit comments

Comments
 (0)