1313// limitations under the License.
1414
1515use std:: {
16- collections:: { hash_map:: Entry , HashMap , HashSet } ,
16+ collections:: { hash_map:: Entry , BTreeMap , HashMap , HashSet } ,
1717 convert:: Infallible ,
1818 sync:: { Arc , RwLock as StdRwLock } ,
1919 time:: { Duration , Instant } ,
@@ -54,6 +54,7 @@ pub struct MemoryStore {
5454 account : StdRwLock < Option < Account > > ,
5555 sessions : SessionStore ,
5656 inbound_group_sessions : GroupSessionStore ,
57+ outbound_group_sessions : StdRwLock < BTreeMap < OwnedRoomId , OutboundGroupSession > > ,
5758 olm_hashes : StdRwLock < HashMap < String , HashSet < String > > > ,
5859 devices : DeviceStore ,
5960 identities : StdRwLock < HashMap < OwnedUserId , ReadOnlyUserIdentities > > ,
@@ -74,6 +75,7 @@ impl Default for MemoryStore {
7475 account : Default :: default ( ) ,
7576 sessions : SessionStore :: new ( ) ,
7677 inbound_group_sessions : GroupSessionStore :: new ( ) ,
78+ outbound_group_sessions : Default :: default ( ) ,
7779 olm_hashes : Default :: default ( ) ,
7880 devices : DeviceStore :: new ( ) ,
7981 identities : Default :: default ( ) ,
@@ -119,6 +121,13 @@ impl MemoryStore {
119121 self . inbound_group_sessions . add ( session) ;
120122 }
121123 }
124+
125+ fn save_outbound_group_sessions ( & self , sessions : Vec < OutboundGroupSession > ) {
126+ self . outbound_group_sessions
127+ . write ( )
128+ . unwrap ( )
129+ . extend ( sessions. into_iter ( ) . map ( |s| ( s. room_id ( ) . to_owned ( ) , s) ) ) ;
130+ }
122131}
123132
124133type Result < T > = std:: result:: Result < T , Infallible > ;
@@ -151,6 +160,7 @@ impl CryptoStore for MemoryStore {
151160 async fn save_changes ( & self , changes : Changes ) -> Result < ( ) > {
152161 self . save_sessions ( changes. sessions ) . await ;
153162 self . save_inbound_group_sessions ( changes. inbound_group_sessions ) ;
163+ self . save_outbound_group_sessions ( changes. outbound_group_sessions ) ;
154164
155165 self . save_devices ( changes. devices . new ) ;
156166 self . save_devices ( changes. devices . changed ) ;
@@ -297,8 +307,11 @@ impl CryptoStore for MemoryStore {
297307 Ok ( self . backup_keys . read ( ) . await . to_owned ( ) )
298308 }
299309
300- async fn get_outbound_group_session ( & self , _: & RoomId ) -> Result < Option < OutboundGroupSession > > {
301- Ok ( None )
310+ async fn get_outbound_group_session (
311+ & self ,
312+ room_id : & RoomId ,
313+ ) -> Result < Option < OutboundGroupSession > > {
314+ Ok ( self . outbound_group_sessions . read ( ) . unwrap ( ) . get ( room_id) . cloned ( ) )
302315 }
303316
304317 async fn load_tracked_users ( & self ) -> Result < Vec < TrackedUser > > {
@@ -487,7 +500,7 @@ mod tests {
487500 }
488501
489502 #[ async_test]
490- async fn test_group_session_store ( ) {
503+ async fn test_inbound_group_session_store ( ) {
491504 let ( account, _) = get_account_and_session_test_helper ( ) ;
492505 let room_id = room_id ! ( "!test:localhost" ) ;
493506 let curve_key = "Nn0L2hkcCMFKqynTjyGsJbth7QrVmX3lbrksMkrGOAw" ;
@@ -511,6 +524,25 @@ mod tests {
511524 assert_eq ! ( inbound, loaded_session) ;
512525 }
513526
527+ #[ async_test]
528+ async fn test_outbound_group_session_store ( ) {
529+ // Given an outbound sessions
530+ let ( account, _) = get_account_and_session_test_helper ( ) ;
531+ let room_id = room_id ! ( "!test:localhost" ) ;
532+ let ( outbound, _) = account. create_group_session_pair_with_defaults ( room_id) . await ;
533+
534+ // When we save it to the store
535+ let store = MemoryStore :: new ( ) ;
536+ store. save_outbound_group_sessions ( vec ! [ outbound. clone( ) ] ) ;
537+
538+ // Then we can get it out again
539+ let loaded_session = store. get_outbound_group_session ( room_id) . await . unwrap ( ) . unwrap ( ) ;
540+ assert_eq ! (
541+ serde_json:: to_string( & outbound. pickle( ) . await ) . unwrap( ) ,
542+ serde_json:: to_string( & loaded_session. pickle( ) . await ) . unwrap( )
543+ ) ;
544+ }
545+
514546 #[ async_test]
515547 async fn test_device_store ( ) {
516548 let device = get_device ( ) ;
0 commit comments