Skip to content

Commit 598b134

Browse files
densumesh and skeptrunedev
authored and committed
feature: add hallucination detection to server
1 parent c732c7e commit 598b134

File tree

21 files changed

+1165
-143
lines changed

21 files changed

+1165
-143
lines changed

frontends/dashboard/src/analytics/components/SingleRagInfo/index.tsx

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ export const SingleRAGQuery = (props: SingleRAGQueryProps) => {
7878
"M/d/yy h:mm a",
7979
)}
8080
</span>
81-
<dl class="m-auto mt-5 grid grid-cols-1 divide-y divide-gray-200 overflow-hidden rounded-lg bg-white shadow md:grid-cols-4 md:divide-x md:divide-y-0">
81+
<dl class="m-auto mt-5 grid grid-cols-1 divide-y divide-gray-200 overflow-hidden rounded-lg bg-white shadow md:grid-cols-5 md:divide-x md:divide-y-0">
8282
<DataSquare label="RAG Type" value={props.rag_data.rag_type} />
8383
<DataSquare
8484
label="Dataset"
@@ -105,6 +105,14 @@ export const SingleRAGQuery = (props: SingleRAGQueryProps) => {
105105
value={props.search_data?.top_score.toPrecision(4) ?? "N/A"}
106106
/>
107107
</Show>
108+
<Show when={props.rag_data && props.rag_data.hallucination_score}>
109+
<DataSquare
110+
label="Hallucination Score"
111+
value={
112+
props.rag_data.hallucination_score?.toPrecision(4) ?? "N/A"
113+
}
114+
/>
115+
</Show>
108116
<Show
109117
when={
110118
props.rag_data.query_rating &&
@@ -126,6 +134,18 @@ export const SingleRAGQuery = (props: SingleRAGQueryProps) => {
126134
</ul>
127135
</Card>
128136
</Show>
137+
<Show
138+
when={
139+
props.rag_data.detected_hallucinations &&
140+
props.rag_data.detected_hallucinations.length > 0
141+
}
142+
>
143+
<Card title="Detected Hallucinations">
144+
<ul>
145+
<li>{props.rag_data.detected_hallucinations?.join(",")}</li>
146+
</ul>
147+
</Card>
148+
</Show>
129149
<Show
130150
when={
131151
(props.search_data?.results && props.search_data.results[0]) ||

frontends/dashboard/src/analytics/pages/tablePages/RAGAnalyticsPage.tsx

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,10 @@ export const RAGAnalyticsPage = () => {
111111
);
112112
},
113113
},
114+
{
115+
accessorKey: "hallucination_score",
116+
header: "Hallucination Score",
117+
},
114118
{
115119
accessorKey: "query_rating",
116120
header: "Query Rating",

frontends/shared/types.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -634,6 +634,8 @@ export interface RagQueryEvent {
634634
note?: string;
635635
rating: number;
636636
};
637+
hallucination_score?: number;
638+
detected_hallucinations?: string[];
637639
}
638640

639641
export interface EventData {

hallucination-detection/Cargo.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

hallucination-detection/Cargo.toml

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "hallucination-detection"
3-
version = "0.1.3"
3+
version = "0.1.5"
44
edition = "2021"
55
license = "MIT"
66
repository = "https://github.com/devflowinc/trieve"
@@ -11,9 +11,10 @@ description = "Extremely fast Hallucination Detection for RAG using BERT NER, no
1111
inherits = "release"
1212

1313
[features]
14-
default = []
14+
default = ["ner"]
1515
ner = ["rust-bert"]
16-
onnx = ["ort"]
16+
download-onnx = ["ort?/download-binaries"]
17+
onnx = ["ort", "rust-bert?/onnx"]
1718

1819
[dependencies]
1920
# Core dependencies
@@ -22,15 +23,14 @@ regex = "1.11.1"
2223
serde = { version = "1.0.215", features = ["derive"] }
2324
tokio = { version = "1.42.0", features = ["full"] }
2425
once_cell = "1.18"
25-
26-
# Optional dependencies for NER feature
27-
rust-bert = { version = "0.23.0", features = ["onnx"], optional = true }
26+
rust-bert = { version = "0.23.0", optional = true }
2827
ort = { version = "1.16.3", features = [
29-
"download-binaries",
3028
"load-dynamic",
31-
], optional = true }
29+
], optional = true, default-features = false }
30+
3231

3332
[dev-dependencies]
33+
ort = { version = "1.16.3", features = ["download-binaries", "load-dynamic"] }
3434
csv = "1.3.1"
3535
dotenvy = "0.15.7"
3636
openai_dive = "0.7.0"

hallucination-detection/README.md

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,32 @@ Add this to your `Cargo.toml`:
2525
hallucination-detection = "^0.1.3"
2626
```
2727

28-
If you want to use NER and ONNX features:
28+
If you want to use NER:
29+
30+
1. Download `libtorch` from <https://pytorch.org/get-started/locally/>. This package requires libtorch `v2.4`. If that version is no longer listed on the "get started" page, it should still be downloadable by editing the version in the download link, for example `https://download.pytorch.org/libtorch/cpu/libtorch-cxx11-abi-shared-with-deps-2.4.0%2Bcpu.zip` for the Linux CPU build.
31+
2. Extract the library to a location of your choice
32+
3. Set the following environment variables
33+
##### Linux:
34+
```bash
35+
export LIBTORCH=/path/to/libtorch
36+
export LD_LIBRARY_PATH=${LIBTORCH}/lib:$LD_LIBRARY_PATH
37+
```
38+
##### Windows
39+
```powershell
40+
$Env:LIBTORCH = "X:\path\to\libtorch"
41+
$Env:Path += ";X:\path\to\libtorch\lib"
42+
```
2943

3044
```toml
3145
[dependencies]
46+
hallucination-detection = { version = "^0.1.3", features = ["ner"] }
47+
```
48+
49+
If you want to use ONNX for the NER models, you need to either [install the ort runtime](https://docs.rs/ort/1.16.3/ort/#how-to-get-binaries) or include it in your dependencies:
50+
51+
```toml
3252
hallucination-detection = { version = "^0.1.3", features = ["ner", "onnx"] }
53+
ort = { version = "...", features = [ "download-binaries" ] }
3354
```
3455

3556
## Quick Start

hallucination-detection/examples/rag_truth_test.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ async fn run_hallucination_test() -> Result<(), Box<dyn Error>> {
131131
let start = std::time::Instant::now();
132132
let hallucination_score = detector
133133
.detect_hallucinations(&record.response, &[source_info.clone()])
134-
.await;
134+
.await.unwrap();
135135
let elapsed = start.elapsed();
136136
println!("Hallucination detection took: {:?}", elapsed);
137137

hallucination-detection/examples/vectara_test.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,8 @@ async fn run_hallucination_test() -> Result<(), Box<dyn Error>> {
8888
let start = std::time::Instant::now();
8989
let hallucination_score = detector
9090
.detect_hallucinations(&record.og_sum, &references)
91-
.await;
91+
.await
92+
.unwrap();
9293
let elapsed = start.elapsed();
9394
println!("Hallucination detection took: {:?}", elapsed);
9495

hallucination-detection/src/lib.rs

Lines changed: 62 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -7,21 +7,32 @@ use std::{
77
};
88
use tokio::sync::OnceCell;
99

10+
#[cfg(not(feature = "ner"))]
11+
#[cfg(feature = "onnx")]
12+
compile_error!("NER feature must be enabled to use ONNX model");
13+
1014
#[cfg(feature = "ner")]
1115
use {
1216
rust_bert::{
13-
pipelines::{
14-
common::{ModelResource, ModelType, ONNXModelResources},
15-
ner::{Entity, NERModel},
16-
token_classification::{LabelAggregationOption, TokenClassificationConfig},
17-
},
18-
resources::RemoteResource,
17+
pipelines::ner::{Entity, NERModel},
18+
pipelines::token_classification::TokenClassificationConfig,
1919
RustBertError,
2020
},
21+
std::error::Error,
2122
std::sync::mpsc,
2223
tokio::{sync::oneshot, task::JoinHandle},
2324
};
2425

26+
#[cfg(feature = "onnx")]
27+
use rust_bert::{
28+
pipelines::{
29+
common::ONNXModelResources,
30+
common::{ModelResource, ModelType},
31+
token_classification::LabelAggregationOption,
32+
},
33+
resources::RemoteResource,
34+
};
35+
2536
const WORDS_URL: &str =
2637
"https://raw.githubusercontent.com/dwyl/english-words/refs/heads/master/words.txt";
2738
const CACHE_FILE: &str = "~/.cache/hallucination-detection/english_words_cache.txt";
@@ -72,6 +83,18 @@ pub struct HallucinationScore {
7283
pub detected_hallucinations: Vec<String>,
7384
}
7485

86+
#[derive(Debug)]
87+
#[allow(dead_code)]
88+
pub struct DetectorError {
89+
message: String,
90+
}
91+
92+
impl std::fmt::Display for DetectorError {
93+
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
94+
write!(f, "Detector Error: {}", self.message)
95+
}
96+
}
97+
7598
#[derive(Debug, Clone)]
7699
pub struct ScoreWeights {
77100
pub proper_noun_weight: f64,
@@ -201,13 +224,20 @@ impl HallucinationDetector {
201224
&self,
202225
llm_output: &String,
203226
references: &[String],
204-
) -> HallucinationScore {
227+
) -> Result<HallucinationScore, DetectorError> {
205228
let mut all_texts = vec![llm_output.to_string()];
206229
all_texts.extend(references.iter().cloned());
207230

208-
let all_analyses = self.analyze_text(&all_texts).await;
231+
let all_analyses = self.analyze_text(&all_texts).await?;
209232

210-
let (output_analysis, ref_analyses) = all_analyses.split_first().unwrap();
233+
let (output_analysis, ref_analyses) = match all_analyses.split_first() {
234+
Some((output_analysis, ref_analyses)) => (output_analysis, ref_analyses),
235+
None => {
236+
return Err(DetectorError {
237+
message: "Failed to analyze text".to_string(),
238+
});
239+
}
240+
};
211241

212242
let all_ref_proper_nouns: HashSet<_> = ref_analyses
213243
.iter()
@@ -250,7 +280,7 @@ impl HallucinationDetector {
250280
+ number_mismatch_score * self.options.weights.number_mismatch_weight)
251281
.clamp(0.0, 1.0);
252282

253-
HallucinationScore {
283+
Ok(HallucinationScore {
254284
proper_noun_score,
255285
unknown_word_score,
256286
number_mismatch_score,
@@ -261,14 +291,19 @@ impl HallucinationDetector {
261291
number_diff.iter().map(|n| n.to_string()).collect(),
262292
]
263293
.concat(),
264-
}
294+
})
265295
}
266296

267297
#[allow(unused_variables)]
268-
async fn analyze_text(&self, texts: &[String]) -> Vec<TextAnalysis> {
298+
async fn analyze_text(&self, texts: &[String]) -> Result<Vec<TextAnalysis>, DetectorError> {
269299
#[cfg(feature = "ner")]
270300
let entities = if let Some(ner_model) = &self.ner_model {
271-
ner_model.predict(texts.to_vec()).await.unwrap()
301+
ner_model
302+
.predict(texts.to_vec())
303+
.await
304+
.map_err(|e| DetectorError {
305+
message: format!("Failed to predict entities: {:?}", e),
306+
})?
272307
} else {
273308
vec![Vec::new(); texts.len()]
274309
};
@@ -322,11 +357,11 @@ impl HallucinationDetector {
322357
true
323358
});
324359
}
325-
TextAnalysis {
360+
Ok(TextAnalysis {
326361
proper_nouns,
327362
unknown_words,
328363
numbers,
329-
}
364+
})
330365
})
331366
.collect()
332367
}
@@ -381,7 +416,8 @@ mod tests {
381416

382417
let score = detector
383418
.detect_hallucinations(&llm_output, &references)
384-
.await;
419+
.await
420+
.unwrap();
385421
println!("Zero Hallucination Score: {:?}", score);
386422

387423
assert_eq!(score.proper_noun_score, 0.0);
@@ -402,7 +438,8 @@ mod tests {
402438

403439
let score = detector
404440
.detect_hallucinations(&llm_output, &references)
405-
.await;
441+
.await
442+
.unwrap();
406443
println!("Multiple References Score: {:?}", score);
407444
assert_eq!(score.proper_noun_score, 0.0); // Both companies are in references
408445
assert_eq!(score.number_mismatch_score, 0.0); // Number matches reference
@@ -415,13 +452,15 @@ mod tests {
415452
// Empty input
416453
let score_empty = detector
417454
.detect_hallucinations(&String::from(""), &[String::from("")])
418-
.await;
455+
.await
456+
.unwrap();
419457
assert_eq!(score_empty.total_score, 0.0);
420458

421459
// Only numbers
422460
let score_numbers = detector
423461
.detect_hallucinations(&String::from("123 456.789"), &[String::from("123 456.789")])
424-
.await;
462+
.await
463+
.unwrap();
425464
assert_eq!(score_numbers.number_mismatch_score, 0.0);
426465

427466
// Only proper nouns
@@ -430,7 +469,8 @@ mod tests {
430469
&String::from("John Smith"),
431470
&[String::from("Different Person")],
432471
)
433-
.await;
472+
.await
473+
.unwrap();
434474
assert!(score_nouns.proper_noun_score > 0.0);
435475
}
436476

@@ -508,7 +548,8 @@ mod tests {
508548
&String::from(llm_output),
509549
&references.into_iter().map(String::from).collect::<Vec<_>>(),
510550
)
511-
.await;
551+
.await
552+
.unwrap();
512553

513554
println!("Test '{}' Score: {:?}", test_name, score);
514555

pdf2md/server/Cargo.lock

Lines changed: 4 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)