Skip to content

Commit 878c7ab

Browse files
committed
Support GetMemoriesById.
Change-Id: I5cd6cc37dc0e4fcf72306e380b7a971aeb0954f9 BUG: 482877275
1 parent 230a521 commit 878c7ab

File tree

9 files changed

+173
-9
lines changed

9 files changed

+173
-9
lines changed

oak_private_memory/app/handler.rs

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,21 @@ impl SealedMemorySessionHandler {
179179
Ok(GetMemoryByIdResponse { memory, success })
180180
}
181181

182+
pub async fn get_memories_by_id_handler(
183+
&self,
184+
request: GetMemoriesByIdRequest,
185+
) -> tonic::Result<GetMemoriesByIdResponse> {
186+
let mut mutex_guard = self.session_context().await;
187+
let database =
188+
&mut mutex_guard.as_mut().into_failed_precondition("call key sync first")?.database;
189+
190+
let (memories, not_found_ids) = database
191+
.get_memories_by_id(request.ids, &request.result_mask)
192+
.await
193+
.into_internal_error("failed to get memories by id")?;
194+
Ok(GetMemoriesByIdResponse { memories, not_found_ids })
195+
}
196+
182197
pub async fn reset_memory_handler(
183198
&self,
184199
_request: ResetMemoryRequest,
@@ -481,6 +496,9 @@ impl SealedMemorySessionHandler {
481496
sealed_memory_request::Request::DeleteMemoryRequest(request) => {
482497
self.delete_memory_handler(request).await?.into_response()
483498
}
499+
sealed_memory_request::Request::GetMemoriesByIdRequest(request) => {
500+
self.get_memories_by_id_handler(request).await?.into_response()
501+
}
484502
};
485503
let elapsed_time = start_time.elapsed().as_millis() as u64;
486504
self.metrics.record_latency(elapsed_time, metric_name);

oak_private_memory/app/packing.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,3 +125,5 @@ impl_packing!(Response => GetMemoryByIdResponse);
125125
impl_packing!(Response => SearchMemoryResponse);
126126
impl_packing!(Response => DeleteMemoryResponse);
127127
impl_packing!(Response => UserRegistrationResponse);
128+
impl_packing!(Request => GetMemoriesByIdRequest);
129+
impl_packing!(Response => GetMemoriesByIdResponse);

oak_private_memory/database/database_with_cache.rs

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,31 @@ impl DatabaseWithCache {
166166
}
167167
}
168168

169+
/// Returns memories for the given IDs.
170+
/// Memories can be returned in an arbitrary order.
171+
/// If any memory is not found, it will be skipped and its ID collected in
172+
/// the second return value.
173+
pub async fn get_memories_by_id(
174+
&mut self,
175+
ids: Vec<MemoryId>,
176+
result_mask: &Option<ResultMask>,
177+
) -> anyhow::Result<(Vec<Memory>, Vec<MemoryId>)> {
178+
let mut found_blob_ids: Vec<BlobId> = Vec::new();
179+
let mut not_found_ids: Vec<MemoryId> = Vec::new();
180+
181+
for id in ids {
182+
match self.meta_db().get_blob_id_by_memory_id(id.clone())? {
183+
Some(blob_id) => found_blob_ids.push(blob_id),
184+
None => not_found_ids.push(id),
185+
}
186+
}
187+
188+
let mut memories = self.cache.get_memories_by_blob_ids(&found_blob_ids).await?;
189+
Self::apply_mask_to_memories(&mut memories, result_mask);
190+
191+
Ok((memories, not_found_ids))
192+
}
193+
169194
pub async fn reset_memory(&mut self) -> anyhow::Result<()> {
170195
let all_memory_ids = self.meta_db().get_all_memory_ids()?;
171196
if !all_memory_ids.is_empty() {

oak_private_memory/proto/build.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
4444
"oak.private_memory.KeySyncResponse",
4545
"oak.private_memory.GetMemoryByIdRequest",
4646
"oak.private_memory.GetMemoryByIdResponse",
47+
"oak.private_memory.GetMemoriesByIdRequest",
48+
"oak.private_memory.GetMemoriesByIdResponse",
4749
"oak.private_memory.SearchMemoryRequest",
4850
"oak.private_memory.SearchMemoryResponse",
4951
"oak.private_memory.Embedding",

oak_private_memory/proto/prelude.rs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,14 @@ pub mod v1 {
1717
pub use crate::oak::private_memory::{
1818
AddMemoryRequest, AddMemoryResponse, DataBlob, DeleteMemoryRequest, DeleteMemoryResponse,
1919
Embedding, EmbeddingQuery, EmbeddingQueryMetricType, EncryptedDataBlob,
20-
EncryptedMetadataBlob, EncryptedUserInfo, GetMemoriesRequest, GetMemoriesResponse,
21-
GetMemoryByIdRequest, GetMemoryByIdResponse, InvalidRequestResponse, KeyDerivationInfo,
22-
KeySyncRequest, KeySyncResponse, Memory, MemoryContent, MemoryField, MemoryValue,
23-
PlainTextUserInfo, ResetMemoryRequest, ResetMemoryResponse, ResultMask, ScoreRange,
24-
SealedMemoryCredentials, SealedMemoryRequest, SealedMemoryResponse,
25-
SealedMemorySessionRequest, SealedMemorySessionResponse, SearchMemoryQuery,
26-
SearchMemoryRequest, SearchMemoryResponse, SearchMemoryResultItem, UserDb,
27-
UserRegistrationRequest, UserRegistrationResponse, WrappedDataEncryptionKey,
20+
EncryptedMetadataBlob, EncryptedUserInfo, GetMemoriesByIdRequest, GetMemoriesByIdResponse,
21+
GetMemoriesRequest, GetMemoriesResponse, GetMemoryByIdRequest, GetMemoryByIdResponse,
22+
InvalidRequestResponse, KeyDerivationInfo, KeySyncRequest, KeySyncResponse, Memory,
23+
MemoryContent, MemoryField, MemoryValue, PlainTextUserInfo, ResetMemoryRequest,
24+
ResetMemoryResponse, ResultMask, ScoreRange, SealedMemoryCredentials, SealedMemoryRequest,
25+
SealedMemoryResponse, SealedMemorySessionRequest, SealedMemorySessionResponse,
26+
SearchMemoryQuery, SearchMemoryRequest, SearchMemoryResponse, SearchMemoryResultItem,
27+
UserDb, UserRegistrationRequest, UserRegistrationResponse, WrappedDataEncryptionKey,
2828
key_sync_response, memory_value, sealed_memory_request, sealed_memory_response,
2929
search_memory_query, user_registration_response,
3030
};

oak_private_memory/proto/sealed_memory.proto

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,19 @@ message GetMemoryByIdResponse {
223223
Memory memory = 2;
224224
}
225225

226+
message GetMemoriesByIdRequest {
227+
repeated string ids = 1;
228+
ResultMask result_mask = 2;
229+
}
230+
231+
message GetMemoriesByIdResponse {
232+
// Memories can be returned in an arbitrary order.
233+
// If any memory is not found, it will be skipped and the id will be saved
234+
// in the `not_found_ids` field.
235+
repeated Memory memories = 1;
236+
repeated string not_found_ids = 2;
237+
}
238+
226239
// Metric type for comparing embeddings.
227240
enum EmbeddingQueryMetricType {
228241
DOT_PRODUCT = 0;
@@ -405,6 +418,7 @@ message SealedMemoryRequest {
405418
SearchMemoryRequest search_memory_request = 7;
406419
UserRegistrationRequest user_registration_request = 8;
407420
DeleteMemoryRequest delete_memory_request = 9;
421+
GetMemoriesByIdRequest get_memories_by_id_request = 10;
408422
}
409423

410424
// Optional unique identifier for this request within the session.
@@ -424,6 +438,7 @@ message SealedMemoryResponse {
424438
SearchMemoryResponse search_memory_response = 7;
425439
UserRegistrationResponse user_registration_response = 8;
426440
DeleteMemoryResponse delete_memory_response = 9;
441+
GetMemoriesByIdResponse get_memories_by_id_response = 10;
427442

428443
// A non-OK status result for a given request.
429444
//
@@ -439,7 +454,7 @@ message SealedMemoryResponse {
439454
// TODO: b/474398323 - Update clients to check for this variant.
440455
// TODO: b/474398548 - Update Sealed Memory server to populate this variant
441456
// for error conditions.
442-
google.rpc.Status error_status = 100;
457+
google.rpc.Status error = 100;
443458
}
444459

445460
// Propagated from the request_id from the request.

oak_private_memory/src/client.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,4 +295,15 @@ impl PrivateMemoryClient {
295295
self.invoke(sealed_memory_request::Request::ResetMemoryRequest(request)).await?;
296296
expect_response_type!(response, sealed_memory_response::Response::ResetMemoryResponse)
297297
}
298+
299+
pub async fn get_memories_by_id(
300+
&mut self,
301+
ids: Vec<String>,
302+
result_mask: Option<ResultMask>,
303+
) -> Result<GetMemoriesByIdResponse> {
304+
let request = GetMemoriesByIdRequest { ids, result_mask };
305+
let response =
306+
self.invoke(sealed_memory_request::Request::GetMemoriesByIdRequest(request)).await?;
307+
expect_response_type!(response, sealed_memory_response::Response::GetMemoriesByIdResponse)
308+
}
298309
}

oak_private_memory/src/metrics.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -394,6 +394,7 @@ impl RequestMetricName {
394394
sealed_memory_request::Request::GetMemoryByIdRequest(r) => get_name(r),
395395
sealed_memory_request::Request::SearchMemoryRequest(r) => get_name(r),
396396
sealed_memory_request::Request::DeleteMemoryRequest(r) => get_name(r),
397+
sealed_memory_request::Request::GetMemoriesByIdRequest(r) => get_name(r),
397398
}))
398399
}
399400
}

oak_private_memory/test/client_test.rs

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -783,3 +783,93 @@ async fn search_test_embedding_with_expired_memories() {
783783
assert_eq!(returned_ids.len(), 2);
784784
}
785785
}
786+
787+
#[tokio::test(flavor = "multi_thread")]
788+
async fn test_get_memories_by_id() {
789+
let (addr, _server_join_handle, _db_join_handle, _persistence_join_handle) =
790+
start_server().await.unwrap();
791+
let url = format!("http://{}", addr);
792+
let pm_uid = "test_get_memories_by_id_user";
793+
794+
for &format in [SerializationFormat::BinaryProto, SerializationFormat::Json].iter() {
795+
let mut client =
796+
PrivateMemoryClient::create_with_start_session(&url, pm_uid, TEST_EK, format)
797+
.await
798+
.unwrap();
799+
800+
// Add three memories
801+
let memory1 = Memory {
802+
id: "memory1".to_string(),
803+
tags: vec!["tag1".to_string()],
804+
..Default::default()
805+
};
806+
let memory2 = Memory {
807+
id: "memory2".to_string(),
808+
tags: vec!["tag2".to_string()],
809+
..Default::default()
810+
};
811+
let memory3 = Memory {
812+
id: "memory3".to_string(),
813+
tags: vec!["tag3".to_string()],
814+
..Default::default()
815+
};
816+
817+
client.add_memory(memory1).await.unwrap();
818+
client.add_memory(memory2).await.unwrap();
819+
client.add_memory(memory3).await.unwrap();
820+
821+
// Test fetching multiple memories by ID
822+
let response = client
823+
.get_memories_by_id(
824+
vec!["memory3".to_string(), "memory1".to_string(), "memory2".to_string()],
825+
None,
826+
)
827+
.await
828+
.unwrap();
829+
830+
assert_eq!(response.memories.len(), 3);
831+
assert!(response.not_found_ids.is_empty());
832+
let returned_ids: HashSet<String> =
833+
response.memories.iter().map(|m| m.id.clone()).collect();
834+
assert!(returned_ids.contains("memory1"));
835+
assert!(returned_ids.contains("memory2"));
836+
assert!(returned_ids.contains("memory3"));
837+
838+
// Test fetching a single memory by ID
839+
let response = client.get_memories_by_id(vec!["memory2".to_string()], None).await.unwrap();
840+
assert_eq!(response.memories.len(), 1);
841+
assert_eq!(response.memories[0].id, "memory2");
842+
assert!(response.not_found_ids.is_empty());
843+
844+
// Test fetching with a non-existent ID - should return found ones and report
845+
// not found
846+
let response = client
847+
.get_memories_by_id(
848+
vec![
849+
"memory1".to_string(),
850+
"non_existent_id".to_string(),
851+
"memory3".to_string(),
852+
"another_missing".to_string(),
853+
],
854+
None,
855+
)
856+
.await
857+
.unwrap();
858+
assert_eq!(response.memories.len(), 2);
859+
let returned_ids: HashSet<String> =
860+
response.memories.iter().map(|m| m.id.clone()).collect();
861+
assert!(returned_ids.contains("memory1"));
862+
assert!(returned_ids.contains("memory3"));
863+
assert_eq!(response.not_found_ids.len(), 2);
864+
assert!(response.not_found_ids.contains(&"non_existent_id".to_string()));
865+
assert!(response.not_found_ids.contains(&"another_missing".to_string()));
866+
867+
// Test with all non-existent IDs
868+
let response = client
869+
.get_memories_by_id(vec!["missing1".to_string(), "missing2".to_string()], None)
870+
.await
871+
.unwrap();
872+
assert!(response.memories.is_empty());
873+
assert_eq!(response.not_found_ids.len(), 2);
874+
}
875+
}

0 commit comments

Comments
 (0)