@@ -427,22 +427,55 @@ fn is_session_overshared_for_user(
427427 let recipient_device_ids: BTreeSet < & DeviceId > =
428428 recipient_devices. iter ( ) . map ( |d| d. device_id ( ) ) . collect ( ) ;
429429
430+ let mut shared: Vec < & DeviceId > = Vec :: new ( ) ;
431+
432+ // This duplicates a conservative subset of the logic in
433+ // `OutboundGroupSession::is_shared_with`, because we
434+ // don't have corresponding DeviceData at hand
435+ fn is_actually_shared ( info : & ShareInfo ) -> bool {
436+ match info {
437+ ShareInfo :: Shared ( _) => true ,
438+ ShareInfo :: Withheld ( _) => false ,
439+ }
440+ }
441+
442+ // Collect the devices that have definitely received the session already
430443 let guard = outbound_session. shared_with_set . read ( ) ;
444+ if let Some ( for_user) = guard. get ( user_id) {
445+ shared. extend ( for_user. iter ( ) . filter_map ( |( d, info) | {
446+ if is_actually_shared ( info) {
447+ Some ( AsRef :: < DeviceId > :: as_ref ( d) )
448+ } else {
449+ None
450+ }
451+ } ) ) ;
452+ }
453+
454+ // To be conservative, also collect the devices that would still receive the
455+ // session from a pending to-device request if we don't rotate beforehand
456+ let guard = outbound_session. to_share_with_set . read ( ) ;
457+ for ( _txid, share_infos) in guard. values ( ) {
458+ if let Some ( for_user) = share_infos. get ( user_id) {
459+ shared. extend ( for_user. iter ( ) . filter_map ( |( d, info) | {
460+ if is_actually_shared ( info) {
461+ Some ( AsRef :: < DeviceId > :: as_ref ( d) )
462+ } else {
463+ None
464+ }
465+ } ) ) ;
466+ }
467+ }
431468
432- let Some ( shared) = guard . get ( user_id ) else {
469+ if shared. is_empty ( ) {
433470 return false ;
434- } ;
471+ }
435472
436- // Devices that received this session
437- let shared: BTreeSet < & DeviceId > = shared
438- . iter ( )
439- . filter ( |( _, info) | matches ! ( info, ShareInfo :: Shared ( _) ) )
440- . map ( |( d, _) | d. as_ref ( ) )
441- . collect ( ) ;
473+ let shared: BTreeSet < & DeviceId > = shared. into_iter ( ) . collect ( ) ;
442474
443475 // The set difference between
444476 //
445- // 1. Devices that had previously received the session, and
477+ // 1. Devices that had previously received (or are queued to receive) the
478+ // session, and
446479 // 2. Devices that would now receive the session
447480 //
448481 // Represents newly deleted or blacklisted devices. If this
@@ -729,17 +762,21 @@ mod tests {
729762 } ,
730763 } ;
731764 use ruma:: {
732- device_id, events:: room:: history_visibility:: HistoryVisibility , room_id, TransactionId ,
765+ device_id,
766+ events:: { dummy:: ToDeviceDummyEventContent , room:: history_visibility:: HistoryVisibility } ,
767+ room_id, TransactionId ,
733768 } ;
734769 use serde_json:: json;
735770
736771 use crate :: {
737772 error:: SessionRecipientCollectionError ,
738- olm:: OutboundGroupSession ,
773+ olm:: { OutboundGroupSession , ShareInfo } ,
739774 session_manager:: {
740775 group_sessions:: share_strategy:: collect_session_recipients, CollectStrategy ,
741776 } ,
777+ store:: caches:: SequenceNumber ,
742778 testing:: simulate_key_query_response_for_verification,
779+ types:: requests:: ToDeviceRequest ,
743780 CrossSigningKeyExport , EncryptionSettings , LocalTrust , OlmError , OlmMachine ,
744781 } ;
745782
@@ -2136,6 +2173,61 @@ mod tests {
21362173 assert ! ( share_result. should_rotate) ;
21372174 }
21382175
2176+ /// Test that the session is rotated if a devices has a pending
2177+ /// to-device request that would share the keys with it.
2178+ #[ async_test]
2179+ async fn test_should_rotate_based_on_device_with_pending_request_excluded ( ) {
2180+ let machine = test_machine ( ) . await ;
2181+ import_known_users_to_test_machine ( & machine) . await ;
2182+
2183+ let encryption_settings = all_devices_strategy_settings ( ) ;
2184+ let group_session = create_test_outbound_group_session ( & machine, & encryption_settings) ;
2185+ let sender_key = machine. identity_keys ( ) . curve25519 ;
2186+
2187+ let dan_user = KeyDistributionTestData :: dan_id ( ) ;
2188+ let dan_dev1 = KeyDistributionTestData :: dan_signed_device_id ( ) ;
2189+ let dan_dev2 = KeyDistributionTestData :: dan_unsigned_device_id ( ) ;
2190+
2191+ // Share the session with device 1
2192+ group_session. mark_shared_with ( dan_user, dan_dev1, sender_key) . await ;
2193+
2194+ {
2195+ // Add a pending request to share with device 2
2196+ let share_infos = BTreeMap :: from ( [ (
2197+ dan_user. to_owned ( ) ,
2198+ BTreeMap :: from ( [ (
2199+ dan_dev2. to_owned ( ) ,
2200+ ShareInfo :: new_shared ( sender_key, 0 , SequenceNumber :: default ( ) ) ,
2201+ ) ] ) ,
2202+ ) ] ) ;
2203+
2204+ let txid = TransactionId :: new ( ) ;
2205+ let req = Arc :: new ( ToDeviceRequest :: for_recipients (
2206+ dan_user,
2207+ vec ! [ dan_dev2. to_owned( ) ] ,
2208+ & ruma:: events:: AnyToDeviceEventContent :: Dummy ( ToDeviceDummyEventContent ) ,
2209+ txid. clone ( ) ,
2210+ ) ) ;
2211+ group_session. add_request ( txid, req, share_infos) ;
2212+ }
2213+
2214+ // Remove device 2
2215+ let keys_query = KeyDistributionTestData :: dan_keys_query_response_device_loggedout ( ) ;
2216+ machine. mark_request_as_sent ( & TransactionId :: new ( ) , & keys_query) . await . unwrap ( ) ;
2217+
2218+ // Share again
2219+ let share_result = collect_session_recipients (
2220+ machine. store ( ) ,
2221+ vec ! [ KeyDistributionTestData :: dan_id( ) ] . into_iter ( ) ,
2222+ & encryption_settings,
2223+ & group_session,
2224+ )
2225+ . await
2226+ . unwrap ( ) ;
2227+
2228+ assert ! ( share_result. should_rotate) ;
2229+ }
2230+
21392231 /// Test that the session is not rotated if a devices is removed
21402232 /// but was already withheld from receiving the session.
21412233 #[ async_test]
0 commit comments