Skip to content

Commit e5ad03f

Browse files
authored
Merge pull request #188 from genomoncology/fix-disease-exact-query-ranking
Fix disease search ranking to surface exact-match canonical hits
2 parents a22c2d2 + 42c01af commit e5ad03f

File tree

2 files changed

+162
-43
lines changed

2 files changed

+162
-43
lines changed

spec/07-disease.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,5 +78,5 @@ Exact disease labels should be reranked to the front of the returned page even w
7878
```bash
7979
out="$("$(git rev-parse --show-toplevel)/target/release/biomcp" search disease "colorectal cancer" --limit 10)"
8080
echo "$out" | mustmatch like "| ID | Name | Synonyms |"
81-
echo "$out" | mustmatch like "colorectal cancer"
81+
echo "$out" | mustmatch like "| MONDO:0024331 | colorectal carcinoma |"
8282
```

src/entities/disease.rs

Lines changed: 161 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -401,16 +401,20 @@ fn disease_candidate_labels(hit: &MyDiseaseHit) -> Vec<String> {
401401
deduped
402402
}
403403

404+
fn best_disease_candidate_score(query: &str, hit: &MyDiseaseHit) -> i32 {
405+
disease_candidate_labels(hit)
406+
.into_iter()
407+
.map(|label| disease_candidate_score(query, &label))
408+
.max()
409+
.unwrap_or(i32::MIN / 2)
410+
}
411+
404412
fn scored_best_candidate(query: &str, hits: Vec<MyDiseaseHit>) -> Option<MyDiseaseHit> {
405413
let mut ranked: Vec<(i32, usize, String, MyDiseaseHit)> = hits
406414
.into_iter()
407415
.map(|hit| {
408416
let primary_name = transform::disease::name_from_mydisease_hit(&hit);
409-
let best_score = disease_candidate_labels(&hit)
410-
.into_iter()
411-
.map(|label| disease_candidate_score(query, &label))
412-
.max()
413-
.unwrap_or(i32::MIN / 2);
417+
let best_score = best_disease_candidate_score(query, &hit);
414418
let normalized_len = normalize_disease_text(&primary_name).len();
415419
(best_score, normalized_len, hit.id.clone(), hit)
416420
})
@@ -440,6 +444,53 @@ fn resolver_queries(name_or_id: &str) -> Vec<String> {
440444
queries
441445
}
442446

447+
struct DiseaseSearchCandidate {
448+
hit: MyDiseaseHit,
449+
first_seen_query_idx: usize,
450+
first_seen_upstream_idx: usize,
451+
}
452+
453+
fn rerank_disease_search_hits(
454+
query: &str,
455+
query_hits: Vec<(usize, Vec<MyDiseaseHit>)>,
456+
) -> Vec<MyDiseaseHit> {
457+
let mut deduped: HashMap<String, DiseaseSearchCandidate> = HashMap::new();
458+
for (query_idx, hits) in query_hits {
459+
for (upstream_idx, hit) in hits.into_iter().enumerate() {
460+
deduped
461+
.entry(hit.id.clone())
462+
.or_insert(DiseaseSearchCandidate {
463+
hit,
464+
first_seen_query_idx: query_idx,
465+
first_seen_upstream_idx: upstream_idx,
466+
});
467+
}
468+
}
469+
470+
let mut ranked = deduped
471+
.into_values()
472+
.map(|candidate| {
473+
let display_name = transform::disease::name_from_mydisease_hit(&candidate.hit);
474+
(
475+
best_disease_candidate_score(query, &candidate.hit),
476+
disease_exact_rank(&display_name, query),
477+
candidate.first_seen_query_idx,
478+
candidate.first_seen_upstream_idx,
479+
candidate.hit.id.clone(),
480+
candidate.hit,
481+
)
482+
})
483+
.collect::<Vec<_>>();
484+
ranked.sort_by(|a, b| {
485+
b.0.cmp(&a.0)
486+
.then_with(|| b.1.cmp(&a.1))
487+
.then_with(|| a.2.cmp(&b.2))
488+
.then_with(|| a.3.cmp(&b.3))
489+
.then_with(|| a.4.cmp(&b.4))
490+
});
491+
ranked.into_iter().map(|(_, _, _, _, _, hit)| hit).collect()
492+
}
493+
443494
async fn resolve_disease_hit_by_name(
444495
client: &MyDiseaseClient,
445496
name_or_id: &str,
@@ -1338,50 +1389,55 @@ pub async fn search_page(
13381389
} else {
13391390
(needed.saturating_mul(5)).clamp(needed, 50)
13401391
};
1341-
let resp = client
1342-
.query(
1343-
query,
1344-
fetch_size,
1345-
0,
1346-
filters.source.as_deref(),
1347-
inheritance,
1348-
phenotype,
1349-
onset,
1350-
)
1351-
.await?;
13521392
let prefer_doid = filters
13531393
.source
13541394
.as_deref()
13551395
.map(str::trim)
13561396
.is_some_and(|s| s.eq_ignore_ascii_case("doid"));
1357-
let mut ranked = resp
1358-
.hits
1359-
.iter()
1360-
.enumerate()
1361-
.filter(|(_, hit)| {
1362-
inheritance.is_none_or(|value| inheritance_matches(hit, value))
1363-
&& phenotype.is_none_or(|value| phenotype_matches(hit, value))
1364-
&& onset.is_none_or(|value| onset_matches(hit, value))
1365-
})
1366-
.map(|(idx, hit)| {
1367-
let mut row = transform::disease::from_mydisease_search_hit(hit);
1368-
if prefer_doid && let Some(doid) = transform::disease::doid_from_mydisease_hit(hit) {
1397+
1398+
let mut merged_total = 0usize;
1399+
let mut query_hits = Vec::new();
1400+
for (query_idx, resolved_query) in resolver_queries(query).into_iter().enumerate() {
1401+
let resp = client
1402+
.query(
1403+
&resolved_query,
1404+
fetch_size,
1405+
0,
1406+
filters.source.as_deref(),
1407+
inheritance,
1408+
phenotype,
1409+
onset,
1410+
)
1411+
.await?;
1412+
merged_total = merged_total.max(resp.total);
1413+
let hits = resp
1414+
.hits
1415+
.into_iter()
1416+
.filter(|hit| {
1417+
inheritance.is_none_or(|value| inheritance_matches(hit, value))
1418+
&& phenotype.is_none_or(|value| phenotype_matches(hit, value))
1419+
&& onset.is_none_or(|value| onset_matches(hit, value))
1420+
})
1421+
.collect::<Vec<_>>();
1422+
query_hits.push((query_idx, hits));
1423+
}
1424+
1425+
let ranked_hits = rerank_disease_search_hits(query, query_hits);
1426+
let total = Some(merged_total.max(ranked_hits.len()));
1427+
let results = ranked_hits
1428+
.into_iter()
1429+
.skip(offset)
1430+
.take(limit)
1431+
.map(|hit| {
1432+
let mut row = transform::disease::from_mydisease_search_hit(&hit);
1433+
if prefer_doid && let Some(doid) = transform::disease::doid_from_mydisease_hit(&hit) {
13691434
row.id = doid;
13701435
}
1371-
(disease_exact_rank(&row.name, query), idx, row)
1436+
row
13721437
})
13731438
.collect::<Vec<_>>();
1374-
ranked.sort_by(|a, b| b.0.cmp(&a.0).then_with(|| a.1.cmp(&b.1)));
13751439

1376-
Ok(SearchPage::offset(
1377-
ranked
1378-
.into_iter()
1379-
.skip(offset)
1380-
.take(limit)
1381-
.map(|(_, _, row)| row)
1382-
.collect(),
1383-
Some(resp.total),
1384-
))
1440+
Ok(SearchPage::offset(results, total))
13851441
}
13861442

13871443
pub fn search_query_summary(filters: &DiseaseSearchFilters) -> String {
@@ -1628,12 +1684,75 @@ mod tests {
16281684
}
16291685

16301686
#[test]
1631-
fn disease_candidate_score_prefers_broad_match_over_subtype() {
1632-
let broad = disease_candidate_score("breast cancer", "breast carcinoma");
1633-
let subtype = disease_candidate_score("breast cancer", "sporadic breast carcinoma");
1687+
fn disease_candidate_score_prefers_canonical_colorectal_match_over_subtype() {
1688+
let broad = disease_candidate_score("colorectal cancer", "colorectal carcinoma");
1689+
let subtype = disease_candidate_score(
1690+
"colorectal cancer",
1691+
"hereditary nonpolyposis colorectal cancer type 6",
1692+
);
16341693
assert!(broad > subtype);
16351694
}
16361695

1696+
fn test_disease_hit(
1697+
id: &str,
1698+
disease_name: &str,
1699+
mondo_synonyms: &[&str],
1700+
do_synonyms: &[&str],
1701+
) -> MyDiseaseHit {
1702+
serde_json::from_value(serde_json::json!({
1703+
"_id": id,
1704+
"mondo": {
1705+
"name": disease_name,
1706+
"synonym": mondo_synonyms,
1707+
},
1708+
"disease_ontology": {
1709+
"name": disease_name,
1710+
"synonyms": do_synonyms,
1711+
}
1712+
}))
1713+
.expect("valid disease hit")
1714+
}
1715+
1716+
#[test]
1717+
fn rerank_disease_search_hits_prefers_canonical_exact_candidate_across_query_variants() {
1718+
let canonical = test_disease_hit(
1719+
"MONDO:0024331",
1720+
"colorectal carcinoma",
1721+
&["colorectal cancer"],
1722+
&["colorectal cancer"],
1723+
);
1724+
1725+
let ranked = rerank_disease_search_hits(
1726+
"colorectal cancer",
1727+
vec![
1728+
(
1729+
0,
1730+
vec![test_disease_hit(
1731+
"MONDO:0101010",
1732+
"hereditary nonpolyposis colorectal cancer type 6",
1733+
&[],
1734+
&[],
1735+
)],
1736+
),
1737+
(
1738+
1,
1739+
vec![
1740+
canonical,
1741+
test_disease_hit(
1742+
"MONDO:0101010",
1743+
"hereditary nonpolyposis colorectal cancer type 6",
1744+
&[],
1745+
&[],
1746+
),
1747+
],
1748+
),
1749+
],
1750+
);
1751+
1752+
let ids = ranked.iter().map(|hit| hit.id.as_str()).collect::<Vec<_>>();
1753+
assert_eq!(ids, vec!["MONDO:0024331", "MONDO:0101010"]);
1754+
}
1755+
16371756
#[test]
16381757
fn disease_exact_rank_prefers_exact_then_prefix_then_contains() {
16391758
assert!(

0 commit comments

Comments
 (0)