Skip to content

Commit ccf2494

Browse files
sogue and pamelafox authored
Add minimum score criteria for AI search results (#1417)
* Add minimum score criteria for AI search results
* Adjust input to support precise filtering in different search modes.
* Resolve comparison issue
* Update class style
* Fix parsing
* Add test
* Lint
* Format
* Fix tests

---------

Co-authored-by: Pamela Fox <[email protected]>
1 parent 40e9887 commit ccf2494

File tree

11 files changed

+186
-5
lines changed

11 files changed

+186
-5
lines changed

app/backend/approaches/approach.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,8 @@ async def search(
123123
vectors: List[VectorQuery],
124124
use_semantic_ranker: bool,
125125
use_semantic_captions: bool,
126+
minimum_search_score: Optional[float],
127+
minimum_reranker_score: Optional[float],
126128
) -> List[Document]:
127129
# Use semantic ranker if requested and if retrieval mode is text or hybrid (vectors + text)
128130
if use_semantic_ranker and query_text:
@@ -161,7 +163,17 @@ async def search(
161163
reranker_score=document.get("@search.reranker_score"),
162164
)
163165
)
164-
return documents
166+
167+
qualified_documents = [
168+
doc
169+
for doc in documents
170+
if (
171+
(doc.score or 0) >= (minimum_search_score or 0)
172+
and (doc.reranker_score or 0) >= (minimum_reranker_score or 0)
173+
)
174+
]
175+
176+
return qualified_documents
165177

166178
def get_sources_content(
167179
self, results: List[Document], use_semantic_captions: bool, use_image_citation: bool

app/backend/approaches/chatreadretrieveread.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,9 @@ async def run_until_final_call(
8989
has_vector = overrides.get("retrieval_mode") in ["vectors", "hybrid", None]
9090
use_semantic_captions = True if overrides.get("semantic_captions") and has_text else False
9191
top = overrides.get("top", 3)
92+
minimum_search_score = overrides.get("minimum_search_score", 0.0)
93+
minimum_reranker_score = overrides.get("minimum_reranker_score", 0.0)
94+
9295
filter = self.build_filter(overrides, auth_claims)
9396
use_semantic_ranker = True if overrides.get("semantic_ranker") and has_text else False
9497

@@ -149,7 +152,16 @@ async def run_until_final_call(
149152
if not has_text:
150153
query_text = None
151154

152-
results = await self.search(top, query_text, filter, vectors, use_semantic_ranker, use_semantic_captions)
155+
results = await self.search(
156+
top,
157+
query_text,
158+
filter,
159+
vectors,
160+
use_semantic_ranker,
161+
use_semantic_captions,
162+
minimum_search_score,
163+
minimum_reranker_score,
164+
)
153165

154166
sources_content = self.get_sources_content(results, use_semantic_captions, use_image_citation=False)
155167
content = "\n".join(sources_content)

app/backend/approaches/chatreadretrievereadvision.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,8 @@ async def run_until_final_call(
8787
vector_fields = overrides.get("vector_fields", ["embedding"])
8888
use_semantic_captions = True if overrides.get("semantic_captions") and has_text else False
8989
top = overrides.get("top", 3)
90+
minimum_search_score = overrides.get("minimum_search_score", 0.0)
91+
minimum_reranker_score = overrides.get("minimum_reranker_score", 0.0)
9092
filter = self.build_filter(overrides, auth_claims)
9193
use_semantic_ranker = True if overrides.get("semantic_ranker") and has_text else False
9294

@@ -134,7 +136,16 @@ async def run_until_final_call(
134136
if not has_text:
135137
query_text = None
136138

137-
results = await self.search(top, query_text, filter, vectors, use_semantic_ranker, use_semantic_captions)
139+
results = await self.search(
140+
top,
141+
query_text,
142+
filter,
143+
vectors,
144+
use_semantic_ranker,
145+
use_semantic_captions,
146+
minimum_search_score,
147+
minimum_reranker_score,
148+
)
138149
sources_content = self.get_sources_content(results, use_semantic_captions, use_image_citation=True)
139150
content = "\n".join(sources_content)
140151

app/backend/approaches/retrievethenread.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,8 @@ async def run(
8686

8787
use_semantic_captions = True if overrides.get("semantic_captions") and has_text else False
8888
top = overrides.get("top", 3)
89+
minimum_search_score = overrides.get("minimum_search_score", 0.0)
90+
minimum_reranker_score = overrides.get("minimum_reranker_score", 0.0)
8991
filter = self.build_filter(overrides, auth_claims)
9092
# If retrieval mode includes vectors, compute an embedding for the query
9193
vectors: list[VectorQuery] = []
@@ -95,7 +97,16 @@ async def run(
9597
# Only keep the text query if the retrieval mode uses text, otherwise drop it
9698
query_text = q if has_text else None
9799

98-
results = await self.search(top, query_text, filter, vectors, use_semantic_ranker, use_semantic_captions)
100+
results = await self.search(
101+
top,
102+
query_text,
103+
filter,
104+
vectors,
105+
use_semantic_ranker,
106+
use_semantic_captions,
107+
minimum_search_score,
108+
minimum_reranker_score,
109+
)
99110

100111
user_content = [q]
101112

app/backend/approaches/retrievethenreadvision.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,8 @@ async def run(
8989

9090
use_semantic_captions = True if overrides.get("semantic_captions") and has_text else False
9191
top = overrides.get("top", 3)
92+
minimum_search_score = overrides.get("minimum_search_score", 0.0)
93+
minimum_reranker_score = overrides.get("minimum_reranker_score", 0.0)
9294
filter = self.build_filter(overrides, auth_claims)
9395
use_semantic_ranker = overrides.get("semantic_ranker") and has_text
9496

@@ -107,7 +109,16 @@ async def run(
107109
# Only keep the text query if the retrieval mode uses text, otherwise drop it
108110
query_text = q if has_text else None
109111

110-
results = await self.search(top, query_text, filter, vectors, use_semantic_ranker, use_semantic_captions)
112+
results = await self.search(
113+
top,
114+
query_text,
115+
filter,
116+
vectors,
117+
use_semantic_ranker,
118+
use_semantic_captions,
119+
minimum_search_score,
120+
minimum_reranker_score,
121+
)
111122

112123
image_list: list[ChatCompletionContentPartImageParam] = []
113124
user_content: list[ChatCompletionContentPartParam] = [{"text": q, "type": "text"}]

app/frontend/src/api/models.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ export type ChatAppRequestOverrides = {
2323
exclude_category?: string;
2424
top?: number;
2525
temperature?: number;
26+
minimum_search_score?: number;
27+
minimum_reranker_score?: number;
2628
prompt_template?: string;
2729
prompt_template_prefix?: string;
2830
prompt_template_suffix?: string;

app/frontend/src/pages/ask/Ask.module.css

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,8 @@
5656
}
5757

5858
.askSettingsSeparator {
59+
display: flex;
60+
flex-direction: column;
5961
margin-top: 15px;
6062
}
6163

app/frontend/src/pages/ask/Ask.tsx

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ export function Component(): JSX.Element {
2222
const [promptTemplatePrefix, setPromptTemplatePrefix] = useState<string>("");
2323
const [promptTemplateSuffix, setPromptTemplateSuffix] = useState<string>("");
2424
const [temperature, setTemperature] = useState<number>(0.3);
25+
const [minimumRerankerScore, setMinimumRerankerScore] = useState<number>(0);
26+
const [minimumSearchScore, setMinimumSearchScore] = useState<number>(0);
2527
const [retrievalMode, setRetrievalMode] = useState<RetrievalMode>(RetrievalMode.Hybrid);
2628
const [retrieveCount, setRetrieveCount] = useState<number>(3);
2729
const [useSemanticRanker, setUseSemanticRanker] = useState<boolean>(true);
@@ -92,6 +94,8 @@ export function Component(): JSX.Element {
9294
exclude_category: excludeCategory.length === 0 ? undefined : excludeCategory,
9395
top: retrieveCount,
9496
temperature: temperature,
97+
minimum_reranker_score: minimumRerankerScore,
98+
minimum_search_score: minimumSearchScore,
9599
retrieval_mode: retrievalMode,
96100
semantic_ranker: useSemanticRanker,
97101
semantic_captions: useSemanticCaptions,
@@ -134,6 +138,13 @@ export function Component(): JSX.Element {
134138
setTemperature(newValue);
135139
};
136140

141+
const onMinimumSearchScoreChange = (_ev?: React.SyntheticEvent<HTMLElement, Event>, newValue?: string) => {
142+
setMinimumSearchScore(parseFloat(newValue || "0"));
143+
};
144+
145+
const onMinimumRerankerScoreChange = (_ev?: React.SyntheticEvent<HTMLElement, Event>, newValue?: string) => {
146+
setMinimumRerankerScore(parseFloat(newValue || "0"));
147+
};
137148
const onRetrieveCountChange = (_ev?: React.SyntheticEvent<HTMLElement, Event>, newValue?: string) => {
138149
setRetrieveCount(parseInt(newValue || "3"));
139150
};
@@ -259,6 +270,25 @@ export function Component(): JSX.Element {
259270
snapToStep
260271
/>
261272

273+
<SpinButton
274+
className={styles.askSettingsSeparator}
275+
label="Minimum search score"
276+
min={0}
277+
step={0.01}
278+
defaultValue={minimumSearchScore.toString()}
279+
onChange={onMinimumSearchScoreChange}
280+
/>
281+
282+
<SpinButton
283+
className={styles.askSettingsSeparator}
284+
label="Minimum reranker score"
285+
min={1}
286+
max={4}
287+
step={0.1}
288+
defaultValue={minimumRerankerScore.toString()}
289+
onChange={onMinimumRerankerScoreChange}
290+
/>
291+
262292
<SpinButton
263293
className={styles.askSettingsSeparator}
264294
label="Retrieve this many search results:"

app/frontend/src/pages/chat/Chat.module.css

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,8 @@
9797
}
9898

9999
.chatSettingsSeparator {
100+
display: flex;
101+
flex-direction: column;
100102
margin-top: 15px;
101103
}
102104

app/frontend/src/pages/chat/Chat.tsx

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ const Chat = () => {
3333
const [isConfigPanelOpen, setIsConfigPanelOpen] = useState(false);
3434
const [promptTemplate, setPromptTemplate] = useState<string>("");
3535
const [temperature, setTemperature] = useState<number>(0.3);
36+
const [minimumRerankerScore, setMinimumRerankerScore] = useState<number>(0);
37+
const [minimumSearchScore, setMinimumSearchScore] = useState<number>(0);
3638
const [retrieveCount, setRetrieveCount] = useState<number>(3);
3739
const [retrievalMode, setRetrievalMode] = useState<RetrievalMode>(RetrievalMode.Hybrid);
3840
const [useSemanticRanker, setUseSemanticRanker] = useState<boolean>(true);
@@ -147,6 +149,8 @@ const Chat = () => {
147149
exclude_category: excludeCategory.length === 0 ? undefined : excludeCategory,
148150
top: retrieveCount,
149151
temperature: temperature,
152+
minimum_reranker_score: minimumRerankerScore,
153+
minimum_search_score: minimumSearchScore,
150154
retrieval_mode: retrievalMode,
151155
semantic_ranker: useSemanticRanker,
152156
semantic_captions: useSemanticCaptions,
@@ -212,6 +216,14 @@ const Chat = () => {
212216
setTemperature(newValue);
213217
};
214218

219+
const onMinimumSearchScoreChange = (_ev?: React.SyntheticEvent<HTMLElement, Event>, newValue?: string) => {
220+
setMinimumSearchScore(parseFloat(newValue || "0"));
221+
};
222+
223+
const onMinimumRerankerScoreChange = (_ev?: React.SyntheticEvent<HTMLElement, Event>, newValue?: string) => {
224+
setMinimumRerankerScore(parseFloat(newValue || "0"));
225+
};
226+
215227
const onRetrieveCountChange = (_ev?: React.SyntheticEvent<HTMLElement, Event>, newValue?: string) => {
216228
setRetrieveCount(parseInt(newValue || "3"));
217229
};
@@ -395,6 +407,25 @@ const Chat = () => {
395407
snapToStep
396408
/>
397409

410+
<SpinButton
411+
className={styles.chatSettingsSeparator}
412+
label="Minimum search score"
413+
min={0}
414+
step={0.01}
415+
defaultValue={minimumSearchScore.toString()}
416+
onChange={onMinimumSearchScoreChange}
417+
/>
418+
419+
<SpinButton
420+
className={styles.chatSettingsSeparator}
421+
label="Minimum reranker score"
422+
min={1}
423+
max={4}
424+
step={0.1}
425+
defaultValue={minimumRerankerScore.toString()}
426+
onChange={onMinimumRerankerScoreChange}
427+
/>
428+
398429
<SpinButton
399430
className={styles.chatSettingsSeparator}
400431
label="Retrieve this many search results:"

0 commit comments

Comments
 (0)