Skip to content

Commit c3f10ba

Browse files
committed
Don't sort by timestamp when embedding search is involved
Change-Id: Ie3555315e8fdea70beb54793695c395ed29f0d09
1 parent 122616e commit c3f10ba

File tree

3 files changed

+311
-6
lines changed

3 files changed

+311
-6
lines changed

oak_private_memory/database/icing.rs

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1064,15 +1064,15 @@ impl IcingMetaDatabase {
10641064
_ => bail!("unsupported operator"),
10651065
};
10661066

1067+
// Check upfront if any clause involves embedding search.
1068+
// If so, use embedding-based scoring; otherwise use text-based scoring.
1069+
let has_embedding_search = clauses.clauses.iter().any(Self::is_embedding_search);
1070+
10671071
let mut sub_queries = Vec::new();
1068-
let mut score_spec = icing::ScoringSpecProto::default();
10691072
let mut embedding_vectors = Vec::new();
10701073
let mut metric_type = None;
10711074
for clause in &clauses.clauses {
1072-
let (spec, sub_score_spec) = self.build_query_specs(clause, schema_name)?;
1073-
if let Some(sub_score_spec) = sub_score_spec {
1074-
score_spec = sub_score_spec;
1075-
}
1075+
let (spec, _) = self.build_query_specs(clause, schema_name)?;
10761076
if spec.embedding_query_metric_type.is_some() {
10771077
metric_type = spec.embedding_query_metric_type;
10781078
}
@@ -1095,7 +1095,21 @@ impl IcingMetaDatabase {
10951095
if !schema_name.is_empty() {
10961096
search_spec.schema_type_filters.push(schema_name.to_string());
10971097
}
1098-
Ok((search_spec, Some(score_spec)))
1098+
1099+
// Use embedding scoring if any clause involves embedding search,
1100+
// otherwise use text-based CreationTimestamp scoring.
1101+
let score_spec = if has_embedding_search {
1102+
Some(self.build_scoring_spec())
1103+
} else {
1104+
Some(icing::ScoringSpecProto {
1105+
rank_by: Some(
1106+
icing::scoring_spec_proto::ranking_strategy::Code::CreationTimestamp.into(),
1107+
),
1108+
..Default::default()
1109+
})
1110+
};
1111+
1112+
Ok((search_spec, score_spec))
10991113
}
11001114

11011115
fn build_text_query_specs(

oak_private_memory/proto/sealed_memory.proto

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,14 @@ enum QueryOperator {
295295
QUERY_OPERATOR_OR = 2; // Any clause can match.
296296
}
297297

298+
// Query for searching memories. Supports text-based queries, embedding-based
299+
// (semantic) queries, and combinations using QueryClauses.
300+
//
301+
// Sorting behavior:
302+
// - If the query involves any embedding search (at any nesting level), results
303+
// are sorted by embedding similarity (highest score first).
304+
// - If the query involves only text-based queries, results are sorted by
305+
// creation timestamp (newest first).
298306
message SearchMemoryQuery {
299307
oneof clause {
300308
EmbeddingQuery embedding_query = 1;

oak_private_memory/test/search_test.rs

Lines changed: 283 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,289 @@ fn test_hybrid_search_with_timestamp() -> anyhow::Result<()> {
160160
Ok(())
161161
}
162162

163+
// Regression test: when embedding search is involved in hybrid queries,
164+
// both clause orderings should sort by embedding similarity:
165+
// - tag AND embedding -> sorts by embedding
166+
// - embedding AND tag -> sorts by embedding
167+
// Previously, the last clause's scoring spec would win, causing
168+
// `embedding AND tag` to incorrectly sort by CreationTimestamp.
169+
#[gtest]
170+
fn test_hybrid_search_clause_order_does_not_affect_ranking() -> anyhow::Result<()> {
171+
let mut icing_database = IcingMetaDatabase::new(IcingTempDir::new("hybrid-order-test"))?;
172+
173+
// Add memories with embeddings that have different similarity scores to the
174+
// query. memory1 has higher similarity (1.0+2.0+3.0=6.0) than memory2
175+
// (0.1+0.2+0.3=0.6).
176+
let memory1 = Memory {
177+
id: "memory1".to_string(),
178+
views: Some(LlmViews {
179+
llm_views: vec![LlmView {
180+
id: "view1".to_string(),
181+
embedding: Some(Embedding {
182+
model_signature: "test_model".to_string(),
183+
values: vec![1.0, 2.0, 3.0],
184+
}),
185+
..Default::default()
186+
}],
187+
}),
188+
tags: vec!["test_tag".to_string()],
189+
..Default::default()
190+
};
191+
icing_database.add_memory(&memory1, "blob1".to_string())?;
192+
193+
let memory2 = Memory {
194+
id: "memory2".to_string(),
195+
views: Some(LlmViews {
196+
llm_views: vec![LlmView {
197+
id: "view2".to_string(),
198+
embedding: Some(Embedding {
199+
model_signature: "test_model".to_string(),
200+
values: vec![0.1, 0.2, 0.3],
201+
}),
202+
..Default::default()
203+
}],
204+
}),
205+
tags: vec!["test_tag".to_string()],
206+
..Default::default()
207+
};
208+
icing_database.add_memory(&memory2, "blob2".to_string())?;
209+
210+
let embedding_query = SearchMemoryQuery {
211+
clause: Some(search_memory_query::Clause::EmbeddingQuery(EmbeddingQuery {
212+
embedding: vec![Embedding {
213+
model_signature: "test_model".to_string(),
214+
values: vec![1.0, 1.0, 1.0],
215+
}],
216+
..Default::default()
217+
})),
218+
};
219+
220+
let tag_query = SearchMemoryQuery {
221+
clause: Some(search_memory_query::Clause::TextQuery(TextQuery {
222+
match_type: MatchType::Equal as i32,
223+
field: MemoryField::Tags as i32,
224+
value: Some(text_query::Value::StringVal("test_tag".to_string())),
225+
})),
226+
};
227+
228+
// Test order 1: embedding AND tag
229+
let query_embedding_first = SearchMemoryQuery {
230+
clause: Some(search_memory_query::Clause::QueryClauses(QueryClauses {
231+
query_operator: QueryOperator::And as i32,
232+
clauses: vec![embedding_query.clone(), tag_query.clone()],
233+
})),
234+
};
235+
236+
// Test order 2: tag AND embedding (this was buggy before the fix)
237+
let query_tag_first = SearchMemoryQuery {
238+
clause: Some(search_memory_query::Clause::QueryClauses(QueryClauses {
239+
query_operator: QueryOperator::And as i32,
240+
clauses: vec![tag_query, embedding_query],
241+
})),
242+
};
243+
244+
let (results1, _) = icing_database.search(&query_embedding_first, 10, PageToken::Start)?;
245+
let (results2, _) = icing_database.search(&query_tag_first, 10, PageToken::Start)?;
246+
247+
// Both should return the same results in the same order (sorted by embedding
248+
// similarity).
249+
let blob_ids1: Vec<String> = results1.items.iter().map(|r| r.blob_id.clone()).collect();
250+
let blob_ids2: Vec<String> = results2.items.iter().map(|r| r.blob_id.clone()).collect();
251+
252+
// memory1 (blob1) has higher embedding similarity, so it should come first.
253+
assert_that!(blob_ids1, elements_are![eq(&"blob1"), eq(&"blob2")]);
254+
assert_that!(blob_ids2, elements_are![eq(&"blob1"), eq(&"blob2")]);
255+
256+
// Verify scores are the same regardless of order.
257+
let scores1: Vec<f32> = results1.items.iter().map(|r| r.score).collect();
258+
let scores2: Vec<f32> = results2.items.iter().map(|r| r.score).collect();
259+
assert_that!(scores1, eq(&scores2));
260+
261+
Ok(())
262+
}
263+
264+
// Test nested clauses: { TAG AND { TAG AND EMBEDDING } }
265+
// Embedding is in a nested inner clause; should still sort by embedding.
266+
#[gtest]
267+
fn test_hybrid_search_nested_embedding_inner() -> anyhow::Result<()> {
268+
let mut icing_database = IcingMetaDatabase::new(IcingTempDir::new("nested-inner-test"))?;
269+
270+
// memory1 has higher embedding similarity (1.0+2.0+3.0=6.0)
271+
let memory1 = Memory {
272+
id: "memory1".to_string(),
273+
views: Some(LlmViews {
274+
llm_views: vec![LlmView {
275+
id: "view1".to_string(),
276+
embedding: Some(Embedding {
277+
model_signature: "test_model".to_string(),
278+
values: vec![1.0, 2.0, 3.0],
279+
}),
280+
..Default::default()
281+
}],
282+
}),
283+
tags: vec!["tag1".to_string(), "tag2".to_string()],
284+
..Default::default()
285+
};
286+
icing_database.add_memory(&memory1, "blob1".to_string())?;
287+
288+
// memory2 has lower embedding similarity (0.1+0.2+0.3=0.6)
289+
let memory2 = Memory {
290+
id: "memory2".to_string(),
291+
views: Some(LlmViews {
292+
llm_views: vec![LlmView {
293+
id: "view2".to_string(),
294+
embedding: Some(Embedding {
295+
model_signature: "test_model".to_string(),
296+
values: vec![0.1, 0.2, 0.3],
297+
}),
298+
..Default::default()
299+
}],
300+
}),
301+
tags: vec!["tag1".to_string(), "tag2".to_string()],
302+
..Default::default()
303+
};
304+
icing_database.add_memory(&memory2, "blob2".to_string())?;
305+
306+
let embedding_query = SearchMemoryQuery {
307+
clause: Some(search_memory_query::Clause::EmbeddingQuery(EmbeddingQuery {
308+
embedding: vec![Embedding {
309+
model_signature: "test_model".to_string(),
310+
values: vec![1.0, 1.0, 1.0],
311+
}],
312+
..Default::default()
313+
})),
314+
};
315+
316+
let tag1_query = SearchMemoryQuery {
317+
clause: Some(search_memory_query::Clause::TextQuery(TextQuery {
318+
match_type: MatchType::Equal as i32,
319+
field: MemoryField::Tags as i32,
320+
value: Some(text_query::Value::StringVal("tag1".to_string())),
321+
})),
322+
};
323+
324+
let tag2_query = SearchMemoryQuery {
325+
clause: Some(search_memory_query::Clause::TextQuery(TextQuery {
326+
match_type: MatchType::Equal as i32,
327+
field: MemoryField::Tags as i32,
328+
value: Some(text_query::Value::StringVal("tag2".to_string())),
329+
})),
330+
};
331+
332+
// Build: { TAG1 AND { TAG2 AND EMBEDDING } }
333+
let inner_clause = SearchMemoryQuery {
334+
clause: Some(search_memory_query::Clause::QueryClauses(QueryClauses {
335+
query_operator: QueryOperator::And as i32,
336+
clauses: vec![tag2_query, embedding_query],
337+
})),
338+
};
339+
let outer_query = SearchMemoryQuery {
340+
clause: Some(search_memory_query::Clause::QueryClauses(QueryClauses {
341+
query_operator: QueryOperator::And as i32,
342+
clauses: vec![tag1_query, inner_clause],
343+
})),
344+
};
345+
346+
let (results, _) = icing_database.search(&outer_query, 10, PageToken::Start)?;
347+
let blob_ids: Vec<String> = results.items.iter().map(|r| r.blob_id.clone()).collect();
348+
349+
// Should be sorted by embedding similarity (blob1 first).
350+
assert_that!(blob_ids, elements_are![eq(&"blob1"), eq(&"blob2")]);
351+
352+
Ok(())
353+
}
354+
355+
// Test nested clauses: { { TAG AND EMBEDDING } AND TAG }
356+
// Embedding is in a nested inner clause with outer tag after.
357+
#[gtest]
358+
fn test_hybrid_search_nested_embedding_outer_tag() -> anyhow::Result<()> {
359+
let mut icing_database = IcingMetaDatabase::new(IcingTempDir::new("nested-outer-test"))?;
360+
361+
// memory1 has higher embedding similarity
362+
let memory1 = Memory {
363+
id: "memory1".to_string(),
364+
views: Some(LlmViews {
365+
llm_views: vec![LlmView {
366+
id: "view1".to_string(),
367+
embedding: Some(Embedding {
368+
model_signature: "test_model".to_string(),
369+
values: vec![1.0, 2.0, 3.0],
370+
}),
371+
..Default::default()
372+
}],
373+
}),
374+
tags: vec!["tag1".to_string(), "tag2".to_string()],
375+
..Default::default()
376+
};
377+
icing_database.add_memory(&memory1, "blob1".to_string())?;
378+
379+
// memory2 has lower embedding similarity
380+
let memory2 = Memory {
381+
id: "memory2".to_string(),
382+
views: Some(LlmViews {
383+
llm_views: vec![LlmView {
384+
id: "view2".to_string(),
385+
embedding: Some(Embedding {
386+
model_signature: "test_model".to_string(),
387+
values: vec![0.1, 0.2, 0.3],
388+
}),
389+
..Default::default()
390+
}],
391+
}),
392+
tags: vec!["tag1".to_string(), "tag2".to_string()],
393+
..Default::default()
394+
};
395+
icing_database.add_memory(&memory2, "blob2".to_string())?;
396+
397+
let embedding_query = SearchMemoryQuery {
398+
clause: Some(search_memory_query::Clause::EmbeddingQuery(EmbeddingQuery {
399+
embedding: vec![Embedding {
400+
model_signature: "test_model".to_string(),
401+
values: vec![1.0, 1.0, 1.0],
402+
}],
403+
..Default::default()
404+
})),
405+
};
406+
407+
let tag1_query = SearchMemoryQuery {
408+
clause: Some(search_memory_query::Clause::TextQuery(TextQuery {
409+
match_type: MatchType::Equal as i32,
410+
field: MemoryField::Tags as i32,
411+
value: Some(text_query::Value::StringVal("tag1".to_string())),
412+
})),
413+
};
414+
415+
let tag2_query = SearchMemoryQuery {
416+
clause: Some(search_memory_query::Clause::TextQuery(TextQuery {
417+
match_type: MatchType::Equal as i32,
418+
field: MemoryField::Tags as i32,
419+
value: Some(text_query::Value::StringVal("tag2".to_string())),
420+
})),
421+
};
422+
423+
// Build: { { TAG1 AND EMBEDDING } AND TAG2 }
424+
let inner_clause = SearchMemoryQuery {
425+
clause: Some(search_memory_query::Clause::QueryClauses(QueryClauses {
426+
query_operator: QueryOperator::And as i32,
427+
clauses: vec![tag1_query, embedding_query],
428+
})),
429+
};
430+
let outer_query = SearchMemoryQuery {
431+
clause: Some(search_memory_query::Clause::QueryClauses(QueryClauses {
432+
query_operator: QueryOperator::And as i32,
433+
clauses: vec![inner_clause, tag2_query],
434+
})),
435+
};
436+
437+
let (results, _) = icing_database.search(&outer_query, 10, PageToken::Start)?;
438+
let blob_ids: Vec<String> = results.items.iter().map(|r| r.blob_id.clone()).collect();
439+
440+
// Should be sorted by embedding similarity (blob1 first).
441+
assert_that!(blob_ids, elements_are![eq(&"blob1"), eq(&"blob2")]);
442+
443+
Ok(())
444+
}
445+
163446
#[gtest]
164447
fn test_search_views() -> anyhow::Result<()> {
165448
let mut icing_database = IcingMetaDatabase::new(IcingTempDir::new("embedding-search-test"))?;

0 commit comments

Comments
 (0)