Skip to content

Commit 784fc9e

Browse files
densumesh authored and
skeptrunedev committed
feature: add MMR support
1 parent 04c0ba8 commit 784fc9e

File tree

10 files changed

+429
-90
lines changed

10 files changed

+429
-90
lines changed

frontends/search/src/components/ResultsPage.tsx

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,7 @@ const ResultsPage = (props: ResultsPageProps) => {
251251
if (!dataset) return;
252252

253253
let sort_by;
254+
let mmr;
254255

255256
if (isSortBySearchType(props.search.debounced.sort_by)) {
256257
props.search.debounced.sort_by.rerank_type != ""
@@ -262,6 +263,12 @@ const ResultsPage = (props: ResultsPageProps) => {
262263
: (sort_by = undefined);
263264
}
264265

266+
if (!props.search.debounced.mmr.use_mmr) {
267+
mmr = undefined;
268+
} else {
269+
mmr = props.search.debounced.mmr;
270+
}
271+
265272
const query =
266273
props.search.debounced.multiQueries.length > 0
267274
? props.search.debounced.multiQueries
@@ -280,6 +287,7 @@ const ResultsPage = (props: ResultsPageProps) => {
280287
score_threshold: props.search.debounced.scoreThreshold,
281288
sort_options: {
282289
sort_by: sort_by,
290+
mmr: mmr,
283291
},
284292
slim_chunks: props.search.debounced.slimChunks ?? false,
285293
page_size: props.search.debounced.pageSize ?? 10,

frontends/search/src/components/SearchForm.tsx

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1193,6 +1193,46 @@ const SearchForm = (props: {
11931193
}}
11941194
/>
11951195
</div>
1196+
<div class="flex items-center justify-between space-x-2 p-1">
1197+
<label>Use MMR:</label>
1198+
<input
1199+
class="h-4 w-4"
1200+
type="checkbox"
1201+
checked={tempSearchValues().mmr.use_mmr}
1202+
onChange={(e) => {
1203+
setTempSearchValues((prev) => {
1204+
return {
1205+
...prev,
1206+
mmr: {
1207+
...prev.mmr,
1208+
use_mmr: e.target.checked,
1209+
},
1210+
};
1211+
});
1212+
}}
1213+
/>
1214+
</div>
1215+
<div class="flex items-center justify-between space-x-2 p-1">
1216+
<label>MMR Lambda:</label>
1217+
<input
1218+
class="w-16 rounded border border-neutral-400 p-0.5 text-black"
1219+
type="number"
1220+
value={tempSearchValues().mmr.mmr_lambda}
1221+
onChange={(e) => {
1222+
setTempSearchValues((prev) => {
1223+
return {
1224+
...prev,
1225+
mmr: {
1226+
...prev.mmr,
1227+
mmr_lambda: parseFloat(
1228+
e.currentTarget.value,
1229+
),
1230+
},
1231+
};
1232+
});
1233+
}}
1234+
/>
1235+
</div>
11961236
<div class="px-1 font-bold">Search Refinement</div>
11971237
<div class="flex items-center justify-between space-x-2 p-1">
11981238
<label>Use Quote Negated Words:</label>

frontends/search/src/hooks/useSearch.ts

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,10 @@ export interface SearchOptions {
6868
prioritize_domain_specifc_words: boolean | null;
6969
disableOnWords: string[];
7070
sort_by: SortByField | SortBySearchType;
71+
mmr: {
72+
use_mmr: boolean;
73+
mmr_lambda?: number;
74+
};
7175
pageSize: number;
7276
getTotalPages: boolean;
7377
highlightResults: boolean;
@@ -98,6 +102,9 @@ const initalState: SearchOptions = {
98102
sort_by: {
99103
field: "",
100104
},
105+
mmr: {
106+
use_mmr: false,
107+
},
101108
pageSize: 10,
102109
getTotalPages: true,
103110
correctTypos: false,
@@ -146,6 +153,7 @@ const fromStateToParams = (state: SearchOptions): Params => {
146153
oneTypoWordRangeMax: state.oneTypoWordRangeMax?.toString() ?? "6",
147154
twoTypoWordRangeMin: state.twoTypoWordRangeMin.toString(),
148155
twoTypoWordRangeMax: state.twoTypoWordRangeMax?.toString() ?? "",
156+
mmr: JSON.stringify(state.mmr),
149157
prioritize_domain_specifc_words:
150158
state.prioritize_domain_specifc_words?.toString() ?? "",
151159
disableOnWords: state.disableOnWords.join(","),
@@ -189,6 +197,11 @@ const fromParamsToState = (
189197
initalState.sort_by,
190198
pageSize: parseInt(params.pageSize ?? "10"),
191199
getTotalPages: (params.getTotalPages ?? "true") === "true",
200+
mmr:
201+
(JSON.parse(params.mmr ?? "{}") as {
202+
use_mmr: boolean;
203+
mmr_lambda?: number;
204+
}) ?? initalState.mmr,
192205
correctTypos: (params.correctTypos ?? "false") === "true",
193206
oneTypoWordRangeMin: parseInt(params.oneTypoWordRangeMin ?? "4"),
194207
oneTypoWordRangeMax: parseIntOrNull(params.oneTypoWordRangeMax),

server/src/data/models.rs

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3520,6 +3520,7 @@ impl ApiKeyRequestParams {
35203520
new_message_content: payload.new_message_content,
35213521
topic_id: payload.topic_id,
35223522
user_id: payload.user_id,
3523+
sort_options: payload.sort_options,
35233524
highlight_options: self.highlight_options.or(payload.highlight_options),
35243525
search_type: self.search_type.or(payload.search_type),
35253526
use_group_search: payload.use_group_search,
@@ -6670,6 +6671,17 @@ pub struct SortOptions {
66706671
pub use_weights: Option<bool>,
66716672
/// Tag weights is a JSON object which can be used to boost the ranking of chunks with certain tags. This is useful for when you want to be able to bias towards chunks with a certain tag on the fly. The keys are the tag names and the values are the weights.
66726673
pub tag_weights: Option<HashMap<String, f32>>,
6674+
/// MMR options lets you enable the Maximal Marginal Relevance algorithm for reranking the results (via use_mmr) and tune its relevance/diversity tradeoff (via mmr_lambda). If not specified, this defaults to MMR not being used.
6675+
pub mmr: Option<MmrOptions>,
6676+
}
6677+
6678+
#[derive(Serialize, Deserialize, Debug, Clone, ToSchema, Default)]
6679+
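/// MMR Options lets you configure Maximal Marginal Relevance reranking of the chunks in the result set: whether it is used and the lambda that trades off relevance against diversity. If not specified, the result set keeps its default ordering by the score of the chunks.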
6680+
pub struct MmrOptions {
6681+
/// Set use_mmr to true to use the Maximal Marginal Relevance algorithm to rerank the results.
6682+
pub use_mmr: bool,
6683+
/// Set mmr_lambda to a value between 0.0 and 1.0 to control the tradeoff between relevance and diversity. Closer to 1.0 will give more diverse results, closer to 0.0 will give more relevant results. If not specified, this defaults to 0.5.
6684+
pub mmr_lambda: Option<f32>,
66736685
}
66746686

66756687
#[derive(Serialize, Deserialize, Debug, Clone, ToSchema, Default)]
@@ -6790,6 +6802,9 @@ fn extract_sort_highlight_options(
67906802
if let Some(value) = other.remove("tag_weights") {
67916803
sort_options.tag_weights = serde_json::from_value(value).ok();
67926804
}
6805+
if let Some(value) = other.remove("mmr") {
6806+
sort_options.mmr = serde_json::from_value(value).ok();
6807+
}
67936808

67946809
// Extract highlight options
67956810
if let Some(value) = other.remove("highlight_results") {
@@ -6818,6 +6833,7 @@ fn extract_sort_highlight_options(
68186833
&& sort_options.location_bias.is_none()
68196834
&& sort_options.use_weights.is_none()
68206835
&& sort_options.tag_weights.is_none()
6836+
&& sort_options.mmr.is_none()
68216837
{
68226838
None
68236839
} else {
@@ -7143,6 +7159,7 @@ impl<'de> Deserialize<'de> for CreateMessageReqPayload {
71437159
pub search_type: Option<SearchMethod>,
71447160
pub concat_user_messages_query: Option<bool>,
71457161
pub search_query: Option<String>,
7162+
pub sort_options: Option<SortOptions>,
71467163
pub page_size: Option<u64>,
71477164
pub filters: Option<ChunkFilter>,
71487165
pub score_threshold: Option<f32>,
@@ -7172,6 +7189,7 @@ impl<'de> Deserialize<'de> for CreateMessageReqPayload {
71727189
new_message_content: helper.new_message_content,
71737190
topic_id: helper.topic_id,
71747191
highlight_options,
7192+
sort_options: helper.sort_options,
71757193
search_type: helper.search_type,
71767194
use_group_search: helper.use_group_search,
71777195
concat_user_messages_query: helper.concat_user_messages_query,
@@ -7198,6 +7216,8 @@ impl<'de> Deserialize<'de> for RegenerateMessageReqPayload {
71987216
pub highlight_options: Option<HighlightOptions>,
71997217
pub search_type: Option<SearchMethod>,
72007218
pub concat_user_messages_query: Option<bool>,
7219+
pub sort_options: Option<SortOptions>,
7220+
72017221
pub search_query: Option<String>,
72027222
pub page_size: Option<u64>,
72037223
pub filters: Option<ChunkFilter>,
@@ -7227,6 +7247,7 @@ impl<'de> Deserialize<'de> for RegenerateMessageReqPayload {
72277247
Ok(RegenerateMessageReqPayload {
72287248
topic_id: helper.topic_id,
72297249
highlight_options,
7250+
sort_options: helper.sort_options,
72307251
search_type: helper.search_type,
72317252
concat_user_messages_query: helper.concat_user_messages_query,
72327253
search_query: helper.search_query,
@@ -7254,6 +7275,8 @@ impl<'de> Deserialize<'de> for EditMessageReqPayload {
72547275
pub new_message_content: String,
72557276
pub highlight_options: Option<HighlightOptions>,
72567277
pub search_type: Option<SearchMethod>,
7278+
pub sort_options: Option<SortOptions>,
7279+
72577280
pub use_group_search: Option<bool>,
72587281
pub concat_user_messages_query: Option<bool>,
72597282
pub search_query: Option<String>,
@@ -7284,6 +7307,7 @@ impl<'de> Deserialize<'de> for EditMessageReqPayload {
72847307
Ok(EditMessageReqPayload {
72857308
topic_id: helper.topic_id,
72867309
message_sort_order: helper.message_sort_order,
7310+
sort_options: helper.sort_options,
72877311
new_message_content: helper.new_message_content,
72887312
highlight_options,
72897313
search_type: helper.search_type,

server/src/handlers/message_handler.rs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use crate::{
66
data::models::{
77
self, ChunkMetadata, ChunkMetadataStringTagSet, ChunkMetadataTypes, ContextOptions,
88
DatasetAndOrgWithSubAndPlan, DatasetConfiguration, HighlightOptions, LLMOptions, Pool,
9-
QdrantChunkMetadata, RedisPool, SearchMethod, SuggestType,
9+
QdrantChunkMetadata, RedisPool, SearchMethod, SortOptions, SuggestType,
1010
},
1111
errors::ServiceError,
1212
get_env,
@@ -98,6 +98,8 @@ pub struct CreateMessageReqPayload {
9898
pub search_query: Option<String>,
9999
/// Page size is the number of chunks to fetch during RAG. If 0, then no search will be performed. If specified, this will override the N retrievals to include in the dataset configuration. Default is None.
100100
pub page_size: Option<u64>,
101+
/// Sort Options lets you specify different methods to rerank the chunks in the result set. If not specified, this defaults to the score of the chunks.
102+
pub sort_options: Option<SortOptions>,
101103
/// Filters is a JSON object which can be used to filter chunks. This is useful for when you want to filter chunks by arbitrary metadata. Unlike with tag filtering, there is a performance hit for filtering on metadata.
102104
pub filters: Option<ChunkFilter>,
103105
/// Set score_threshold to a float to filter out chunks with a score below the threshold. This threshold applies before weight and bias modifications. If not specified, this defaults to 0.0.
@@ -349,6 +351,8 @@ pub struct RegenerateMessageReqPayload {
349351
pub search_query: Option<String>,
350352
/// Page size is the number of chunks to fetch during RAG. If 0, then no search will be performed. If specified, this will override the N retrievals to include in the dataset configuration. Default is None.
351353
pub page_size: Option<u64>,
354+
/// Sort Options lets you specify different methods to rerank the chunks in the result set. If not specified, this defaults to the score of the chunks.
355+
pub sort_options: Option<SortOptions>,
352356
/// Filters is a JSON object which can be used to filter chunks. This is useful for when you want to filter chunks by arbitrary metadata. Unlike with tag filtering, there is a performance hit for filtering on metadata.
353357
pub filters: Option<ChunkFilter>,
354358
/// Set score_threshold to a float to filter out chunks with a score below the threshold. This threshold applies before weight and bias modifications. If not specified, this defaults to 0.0.
@@ -381,6 +385,8 @@ pub struct EditMessageReqPayload {
381385
pub concat_user_messages_query: Option<bool>,
382386
/// Query is the search query. This can be any string. The search_query will be used to create a dense embedding vector and/or sparse vector which will be used to find the result set. If not specified, will default to the last user message or HyDE if HyDE is enabled in the dataset configuration. Default is None.
383387
pub search_query: Option<String>,
388+
/// Sort Options lets you specify different methods to rerank the chunks in the result set. If not specified, this defaults to the score of the chunks.
389+
pub sort_options: Option<SortOptions>,
384390
/// Page size is the number of chunks to fetch during RAG. If 0, then no search will be performed. If specified, this will override the N retrievals to include in the dataset configuration. Default is None.
385391
pub page_size: Option<u64>,
386392
/// Filters is a JSON object which can be used to filter chunks. This is useful for when you want to filter chunks by arbitrary metadata. Unlike with tag filtering, there is a performance hit for filtering on metadata.
@@ -404,6 +410,7 @@ impl From<EditMessageReqPayload> for CreateMessageReqPayload {
404410
topic_id: data.topic_id,
405411
highlight_options: data.highlight_options,
406412
search_type: data.search_type,
413+
sort_options: data.sort_options,
407414
use_group_search: data.use_group_search,
408415
concat_user_messages_query: data.concat_user_messages_query,
409416
search_query: data.search_query,
@@ -426,6 +433,7 @@ impl From<RegenerateMessageReqPayload> for CreateMessageReqPayload {
426433
highlight_options: data.highlight_options,
427434
search_type: data.search_type,
428435
use_group_search: data.use_group_search,
436+
sort_options: data.sort_options,
429437
concat_user_messages_query: data.concat_user_messages_query,
430438
search_query: data.search_query,
431439
page_size: data.page_size,

server/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -495,6 +495,7 @@ impl Modify for SecurityAddon {
495495
data::models::OrganizationUsageCount,
496496
data::models::Dataset,
497497
data::models::DatasetAndUsage,
498+
data::models::MmrOptions,
498499
data::models::DatasetUsageCount,
499500
data::models::DatasetDTO,
500501
data::models::DatasetUsageCount,

server/src/operators/chunk_operator.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ pub struct ChunkMetadataWithQdrantId {
138138
pub qdrant_id: uuid::Uuid,
139139
}
140140

141-
pub async fn get_chunk_metadatas_and_collided_chunks_from_point_ids_query(
141+
pub async fn get_chunk_metadatas_from_point_ids_query(
142142
point_ids: Vec<uuid::Uuid>,
143143
pool: web::Data<Pool>,
144144
) -> Result<Vec<ChunkMetadataTypes>, ServiceError> {

server/src/operators/message_operator.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -348,6 +348,7 @@ pub async fn get_rag_chunks_query(
348348
.page_size
349349
.unwrap_or(n_retrievals_to_include.try_into().unwrap_or(8)),
350350
),
351+
sort_options: create_message_req_payload.sort_options,
351352
highlight_options: create_message_req_payload.highlight_options,
352353
filters: create_message_req_payload.filters,
353354
group_size: Some(1),
@@ -453,6 +454,7 @@ pub async fn get_rag_chunks_query(
453454
search_type: search_type.clone(),
454455
query: QueryTypes::Single(query.clone()),
455456
score_threshold: create_message_req_payload.score_threshold,
457+
sort_options: create_message_req_payload.sort_options,
456458
page_size: Some(
457459
create_message_req_payload
458460
.page_size

0 commit comments

Comments
 (0)