@@ -783,3 +783,93 @@ async fn search_test_embedding_with_expired_memories() {
783783 assert_eq ! ( returned_ids. len( ) , 2 ) ;
784784 }
785785}
786+
787+ #[ tokio:: test( flavor = "multi_thread" ) ]
788+ async fn test_get_memories_by_id ( ) {
789+ let ( addr, _server_join_handle, _db_join_handle, _persistence_join_handle) =
790+ start_server ( ) . await . unwrap ( ) ;
791+ let url = format ! ( "http://{}" , addr) ;
792+ let pm_uid = "test_get_memories_by_id_user" ;
793+
794+ for & format in [ SerializationFormat :: BinaryProto , SerializationFormat :: Json ] . iter ( ) {
795+ let mut client =
796+ PrivateMemoryClient :: create_with_start_session ( & url, pm_uid, TEST_EK , format)
797+ . await
798+ . unwrap ( ) ;
799+
800+ // Add three memories
801+ let memory1 = Memory {
802+ id : "memory1" . to_string ( ) ,
803+ tags : vec ! [ "tag1" . to_string( ) ] ,
804+ ..Default :: default ( )
805+ } ;
806+ let memory2 = Memory {
807+ id : "memory2" . to_string ( ) ,
808+ tags : vec ! [ "tag2" . to_string( ) ] ,
809+ ..Default :: default ( )
810+ } ;
811+ let memory3 = Memory {
812+ id : "memory3" . to_string ( ) ,
813+ tags : vec ! [ "tag3" . to_string( ) ] ,
814+ ..Default :: default ( )
815+ } ;
816+
817+ client. add_memory ( memory1) . await . unwrap ( ) ;
818+ client. add_memory ( memory2) . await . unwrap ( ) ;
819+ client. add_memory ( memory3) . await . unwrap ( ) ;
820+
821+ // Test fetching multiple memories by ID
822+ let response = client
823+ . get_memories_by_id (
824+ vec ! [ "memory3" . to_string( ) , "memory1" . to_string( ) , "memory2" . to_string( ) ] ,
825+ None ,
826+ )
827+ . await
828+ . unwrap ( ) ;
829+
830+ assert_eq ! ( response. memories. len( ) , 3 ) ;
831+ assert ! ( response. not_found_ids. is_empty( ) ) ;
832+ let returned_ids: HashSet < String > =
833+ response. memories . iter ( ) . map ( |m| m. id . clone ( ) ) . collect ( ) ;
834+ assert ! ( returned_ids. contains( "memory1" ) ) ;
835+ assert ! ( returned_ids. contains( "memory2" ) ) ;
836+ assert ! ( returned_ids. contains( "memory3" ) ) ;
837+
838+ // Test fetching a single memory by ID
839+ let response = client. get_memories_by_id ( vec ! [ "memory2" . to_string( ) ] , None ) . await . unwrap ( ) ;
840+ assert_eq ! ( response. memories. len( ) , 1 ) ;
841+ assert_eq ! ( response. memories[ 0 ] . id, "memory2" ) ;
842+ assert ! ( response. not_found_ids. is_empty( ) ) ;
843+
844+ // Test fetching with a non-existent ID - should return found ones and report
845+ // not found
846+ let response = client
847+ . get_memories_by_id (
848+ vec ! [
849+ "memory1" . to_string( ) ,
850+ "non_existent_id" . to_string( ) ,
851+ "memory3" . to_string( ) ,
852+ "another_missing" . to_string( ) ,
853+ ] ,
854+ None ,
855+ )
856+ . await
857+ . unwrap ( ) ;
858+ assert_eq ! ( response. memories. len( ) , 2 ) ;
859+ let returned_ids: HashSet < String > =
860+ response. memories . iter ( ) . map ( |m| m. id . clone ( ) ) . collect ( ) ;
861+ assert ! ( returned_ids. contains( "memory1" ) ) ;
862+ assert ! ( returned_ids. contains( "memory3" ) ) ;
863+ assert_eq ! ( response. not_found_ids. len( ) , 2 ) ;
864+ assert ! ( response. not_found_ids. contains( & "non_existent_id" . to_string( ) ) ) ;
865+ assert ! ( response. not_found_ids. contains( & "another_missing" . to_string( ) ) ) ;
866+
867+ // Test with all non-existent IDs
868+ let response = client
869+ . get_memories_by_id ( vec ! [ "missing1" . to_string( ) , "missing2" . to_string( ) ] , None )
870+ . await
871+ . unwrap ( ) ;
872+ assert ! ( response. memories. is_empty( ) ) ;
873+ assert_eq ! ( response. not_found_ids. len( ) , 2 ) ;
874+ }
875+ }
0 commit comments