Skip to content

Commit 486b6d6

Browse files
committed
crypto: Save outbound sessions in MemoryStore
1 parent 32edfb1 commit 486b6d6

File tree

1 file changed

+38
-3
lines changed

1 file changed

+38
-3
lines changed

crates/matrix-sdk-crypto/src/store/memorystore.rs

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ pub struct MemoryStore {
5454
account: StdRwLock<Option<Account>>,
5555
sessions: SessionStore,
5656
inbound_group_sessions: GroupSessionStore,
57+
outbound_group_sessions: StdRwLock<Vec<OutboundGroupSession>>,
5758
olm_hashes: StdRwLock<HashMap<String, HashSet<String>>>,
5859
devices: DeviceStore,
5960
identities: StdRwLock<HashMap<OwnedUserId, ReadOnlyUserIdentities>>,
@@ -74,6 +75,7 @@ impl Default for MemoryStore {
7475
account: Default::default(),
7576
sessions: SessionStore::new(),
7677
inbound_group_sessions: GroupSessionStore::new(),
78+
outbound_group_sessions: Default::default(),
7779
olm_hashes: Default::default(),
7880
devices: DeviceStore::new(),
7981
identities: Default::default(),
@@ -119,6 +121,10 @@ impl MemoryStore {
119121
self.inbound_group_sessions.add(session);
120122
}
121123
}
124+
125+
fn save_outbound_group_sessions(&self, mut sessions: Vec<OutboundGroupSession>) {
126+
self.outbound_group_sessions.write().unwrap().append(&mut sessions);
127+
}
122128
}
123129

124130
type Result<T> = std::result::Result<T, Infallible>;
@@ -151,6 +157,7 @@ impl CryptoStore for MemoryStore {
151157
async fn save_changes(&self, changes: Changes) -> Result<()> {
152158
self.save_sessions(changes.sessions).await;
153159
self.save_inbound_group_sessions(changes.inbound_group_sessions);
160+
self.save_outbound_group_sessions(changes.outbound_group_sessions);
154161

155162
self.save_devices(changes.devices.new);
156163
self.save_devices(changes.devices.changed);
@@ -297,8 +304,17 @@ impl CryptoStore for MemoryStore {
297304
Ok(self.backup_keys.read().await.to_owned())
298305
}
299306

300-
async fn get_outbound_group_session(&self, _: &RoomId) -> Result<Option<OutboundGroupSession>> {
301-
Ok(None)
307+
async fn get_outbound_group_session(
308+
&self,
309+
room_id: &RoomId,
310+
) -> Result<Option<OutboundGroupSession>> {
311+
Ok(self
312+
.outbound_group_sessions
313+
.read()
314+
.unwrap()
315+
.iter()
316+
.find(|session| session.room_id() == room_id)
317+
.cloned())
302318
}
303319

304320
async fn load_tracked_users(&self) -> Result<Vec<TrackedUser>> {
@@ -487,7 +503,7 @@ mod tests {
487503
}
488504

489505
#[async_test]
490-
async fn test_group_session_store() {
506+
async fn test_inbound_group_session_store() {
491507
let (account, _) = get_account_and_session_test_helper();
492508
let room_id = room_id!("!test:localhost");
493509
let curve_key = "Nn0L2hkcCMFKqynTjyGsJbth7QrVmX3lbrksMkrGOAw";
@@ -511,6 +527,25 @@ mod tests {
511527
assert_eq!(inbound, loaded_session);
512528
}
513529

530+
#[async_test]
531+
async fn test_outbound_group_session_store() {
532+
// Given an outbound sessions
533+
let (account, _) = get_account_and_session_test_helper();
534+
let room_id = room_id!("!test:localhost");
535+
let (outbound, _) = account.create_group_session_pair_with_defaults(room_id).await;
536+
537+
// When we save it to the store
538+
let store = MemoryStore::new();
539+
store.save_outbound_group_sessions(vec![outbound.clone()]);
540+
541+
// Then we can get it out again
542+
let loaded_session = store.get_outbound_group_session(room_id).await.unwrap().unwrap();
543+
assert_eq!(
544+
serde_json::to_string(&outbound.pickle().await).unwrap(),
545+
serde_json::to_string(&loaded_session.pickle().await).unwrap()
546+
);
547+
}
548+
514549
#[async_test]
515550
async fn test_device_store() {
516551
let device = get_device();

0 commit comments

Comments
 (0)