1
- from typing import List , Dict
2
- from transformers import pipeline
1
+ from typing import TYPE_CHECKING
2
+
3
+ if TYPE_CHECKING :
4
+ from transformers import pipeline as PipelineType
5
+
6
+ # Optional dependency: imported eagerly; `pipeline` falls back to None if transformers is not installed
7
+ try :
8
+ from transformers import pipeline
9
+ except ImportError :
10
+ pipeline = None
11
+
3
12
4
13
class HallucinationDetector :
5
- """
6
- Simple Hallucination Detector using NLI models (e.g., facebook/bart-large-mnli).
14
+ """Simple Hallucination Detector using NLI models (e.g., facebook/bart-large-mnli).
7
15
- Extract claims (basic sentence split)
8
16
- Verify claims against evidence docs using NLI
9
17
- Compute hallucination rate
10
18
"""
11
19
12
20
def __init__ (self , model_name : str = "facebook/bart-large-mnli" ):
21
+ if pipeline is None :
22
+ raise ImportError (
23
+ "The 'transformers' package is required for HallucinationDetector. "
24
+ "Install it with `pip install transformers`."
25
+ )
13
26
self .nli_pipeline = pipeline ("text-classification" , model = model_name )
14
27
15
- def extract_claims (self , text : str ) -> List [str ]:
28
+ def extract_claims (self , text : str ) -> list [str ]:
16
29
"""Naive sentence-based claim extraction"""
17
30
return [c .strip () for c in text .split ("." ) if c .strip ()]
18
31
@@ -21,14 +34,20 @@ def verify_claim(self, claim: str, evidence: str) -> bool:
21
34
result = self .nli_pipeline (f"{ claim } </s></s> { evidence } " )
22
35
return result [0 ]["label" ].lower () == "entailment"
23
36
24
- def verify_claim_multi (self , claim : str , evidence_docs : List [str ]) -> bool :
37
+ def verify_claim_multi (self , claim : str , evidence_docs : list [str ]) -> bool :
25
38
"""A claim is supported if any evidence doc entails it"""
26
39
return any (self .verify_claim (claim , e ) for e in evidence_docs )
27
40
28
- def compute_hallucination_rate (self , text : str , evidence_docs : List [str ]) -> Dict [str , float ]:
41
+ def compute_hallucination_rate (
42
+ self , text : str , evidence_docs : list [str ]
43
+ ) -> dict [str , float ]:
29
44
claims = self .extract_claims (text )
30
45
if not claims :
31
- return {"total_claims" : 0 , "unsupported_claims" : 0 , "hallucination_rate" : 0.0 }
46
+ return {
47
+ "total_claims" : 0 ,
48
+ "unsupported_claims" : 0 ,
49
+ "hallucination_rate" : 0.0 ,
50
+ }
32
51
33
52
unsupported = sum (not self .verify_claim_multi (c , evidence_docs ) for c in claims )
34
53
return {
0 commit comments