14
14
15
15
use std:: { collections:: BTreeMap , sync:: Arc , time:: Duration } ;
16
16
17
+ use futures_core:: Stream ;
18
+ use futures_util:: { pin_mut, StreamExt } ;
19
+ use matrix_sdk_base:: {
20
+ crypto:: store:: types:: RoomKeyBundleInfo , InviteAcceptanceDetails , RoomState ,
21
+ } ;
17
22
use matrix_sdk_common:: failures_cache:: FailuresCache ;
18
23
use ruma:: {
19
24
events:: room:: encrypted:: { EncryptedEventScheme , OriginalSyncRoomEncryptedEvent } ,
20
25
serde:: Raw ,
21
26
OwnedEventId , OwnedRoomId ,
22
27
} ;
23
- use tokio:: sync:: {
24
- mpsc:: { self , UnboundedReceiver } ,
25
- Mutex ,
26
- } ;
27
- use tracing:: { debug, trace, warn} ;
28
+ use tokio:: sync:: { mpsc, Mutex } ;
29
+ use tracing:: { debug, info, instrument, trace, warn} ;
28
30
29
31
use crate :: {
30
32
client:: WeakClient ,
31
33
encryption:: backups:: UploadState ,
32
34
executor:: { spawn, JoinHandle } ,
33
- Client ,
35
+ room:: shared_room_history,
36
+ Client , Room ,
34
37
} ;
35
38
36
39
/// A cache of room keys we already downloaded.
@@ -41,6 +44,7 @@ pub(crate) struct ClientTasks {
41
44
pub ( crate ) upload_room_keys : Option < BackupUploadingTask > ,
42
45
pub ( crate ) download_room_keys : Option < BackupDownloadTask > ,
43
46
pub ( crate ) update_recovery_state_after_backup : Option < JoinHandle < ( ) > > ,
47
+ pub ( crate ) receive_historic_room_key_bundles : Option < BundleReceiverTask > ,
44
48
pub ( crate ) setup_e2ee : Option < JoinHandle < ( ) > > ,
45
49
}
46
50
@@ -72,7 +76,7 @@ impl BackupUploadingTask {
72
76
let _ = self . sender . send ( ( ) ) ;
73
77
}
74
78
75
- pub ( crate ) async fn listen ( client : WeakClient , mut receiver : UnboundedReceiver < ( ) > ) {
79
+ pub ( crate ) async fn listen ( client : WeakClient , mut receiver : mpsc :: UnboundedReceiver < ( ) > ) {
76
80
while receiver. recv ( ) . await . is_some ( ) {
77
81
if let Some ( client) = client. get ( ) {
78
82
let upload_progress = & client. inner . e2ee . backup_state . upload_progress ;
@@ -176,7 +180,10 @@ impl BackupDownloadTask {
176
180
/// # Arguments
177
181
///
178
182
/// * `receiver` - The source of incoming [`RoomKeyDownloadRequest`]s.
179
- async fn listen ( client : WeakClient , mut receiver : UnboundedReceiver < RoomKeyDownloadRequest > ) {
183
+ async fn listen (
184
+ client : WeakClient ,
185
+ mut receiver : mpsc:: UnboundedReceiver < RoomKeyDownloadRequest > ,
186
+ ) {
180
187
let state = Arc :: new ( Mutex :: new ( BackupDownloadTaskListenerState :: new ( client) ) ) ;
181
188
182
189
while let Some ( room_key_download_request) = receiver. recv ( ) . await {
@@ -385,15 +392,97 @@ impl BackupDownloadTaskListenerState {
385
392
}
386
393
}
387
394
395
+ pub ( crate ) struct BundleReceiverTask {
396
+ _handle : JoinHandle < ( ) > ,
397
+ }
398
+
399
+ impl BundleReceiverTask {
400
+ pub async fn new ( client : & Client ) -> Self {
401
+ let stream = client. encryption ( ) . historic_room_key_stream ( ) . await . expect ( "E2EE tasks should only be initialized once we have logged in and have access to an OlmMachine" ) ;
402
+ let weak_client = WeakClient :: from_client ( client) ;
403
+ let handle = spawn ( Self :: listen_task ( weak_client, stream) ) ;
404
+
405
+ Self { _handle : handle }
406
+ }
407
+
408
+ async fn listen_task ( client : WeakClient , stream : impl Stream < Item = RoomKeyBundleInfo > ) {
409
+ pin_mut ! ( stream) ;
410
+
411
+ // TODO: Listening to this stream is not enough for iOS due to the NSE killing
412
+ // our OlmMachine and thus also this stream. We need to add an event handler
413
+ // that will listen for the bundle event. To be able to add an event handler,
414
+ // we'll have to implement the bundle event in Ruma.
415
+ while let Some ( bundle_info) = stream. next ( ) . await {
416
+ let Some ( client) = client. get ( ) else {
417
+ // The client was dropped while we were waiting on the stream. Let's end the
418
+ // loop, since this means that the application has shut down.
419
+ break ;
420
+ } ;
421
+
422
+ let Some ( room) = client. get_room ( & bundle_info. room_id ) else {
423
+ warn ! ( room_id = %bundle_info. room_id, "Received a historic room key bundle for an unknown room" ) ;
424
+ continue ;
425
+ } ;
426
+
427
+ Self :: handle_bundle ( & room, & bundle_info) . await ;
428
+ }
429
+ }
430
+
431
+ #[ instrument( skip( room) , fields( room_id = %room. room_id( ) ) ) ]
432
+ async fn handle_bundle ( room : & Room , bundle_info : & RoomKeyBundleInfo ) {
433
+ if Self :: should_accept_bundle ( room, bundle_info) {
434
+ info ! ( "Accepting a late key bundle." ) ;
435
+
436
+ if let Err ( e) =
437
+ shared_room_history:: maybe_accept_key_bundle ( room, & bundle_info. sender ) . await
438
+ {
439
+ warn ! ( "Couldn't accept a late room key bundle {e:?}" ) ;
440
+ }
441
+ } else {
442
+ info ! ( "Refusing to accept a historic room key bundle." ) ;
443
+ }
444
+ }
445
+
446
+ fn should_accept_bundle ( room : & Room , bundle_info : & RoomKeyBundleInfo ) -> bool {
447
+ // We accept historic room key bundles up to one day after we have accepted an
448
+ // invite.
449
+ const DAY : Duration = Duration :: from_secs ( 24 * 60 * 60 ) ;
450
+
451
+ // If we don't have any invite acceptance details, then this client wasn't the
452
+ // one that accepted the invite.
453
+ let Some ( InviteAcceptanceDetails { invite_accepted_at, inviter } ) =
454
+ room. invite_acceptance_details ( )
455
+ else {
456
+ return false ;
457
+ } ;
458
+
459
+ let state = room. state ( ) ;
460
+ let elapsed_since_join = invite_accepted_at. to_system_time ( ) . and_then ( |t| t. elapsed ( ) . ok ( ) ) ;
461
+ let bundle_sender = & bundle_info. sender ;
462
+
463
+ match ( state, elapsed_since_join) {
464
+ ( RoomState :: Joined , Some ( elapsed_since_join) ) => {
465
+ elapsed_since_join < DAY && bundle_sender == & inviter
466
+ }
467
+ ( RoomState :: Joined , None ) => false ,
468
+ ( RoomState :: Left | RoomState :: Invited | RoomState :: Knocked | RoomState :: Banned , _) => {
469
+ false
470
+ }
471
+ }
472
+ }
473
+ }
474
+
388
475
#[ cfg( all( test, not( target_family = "wasm" ) ) ) ]
389
476
mod test {
390
- use matrix_sdk_test:: async_test;
391
- use ruma:: { event_id, room_id} ;
477
+ use matrix_sdk_test:: {
478
+ async_test, event_factory:: EventFactory , InvitedRoomBuilder , JoinedRoomBuilder ,
479
+ } ;
480
+ use ruma:: { event_id, room_id, user_id} ;
392
481
use serde_json:: json;
393
482
use wiremock:: MockServer ;
394
483
395
484
use super :: * ;
396
- use crate :: test_utils:: logged_in_client;
485
+ use crate :: test_utils:: { logged_in_client, mocks :: MatrixMockServer } ;
397
486
398
487
// Test that, if backups are not enabled, we don't incorrectly mark a room key
399
488
// as downloaded.
@@ -451,4 +540,81 @@ mod test {
451
540
)
452
541
}
453
542
}
543
+
544
+ /// Test that ensures that we only accept a bundle if a certain set of
545
+ /// conditions is met.
546
+ #[ async_test]
547
+ async fn test_should_accept_bundle ( ) {
548
+ let server = MatrixMockServer :: new ( ) . await ;
549
+
550
+ let alice_user_id = user_id ! ( "@alice:localhost" ) ;
551
+ let bob_user_id = user_id ! ( "@bob:localhost" ) ;
552
+ let joined_room_id = room_id ! ( "!joined:localhost" ) ;
553
+ let invited_rom_id = room_id ! ( "!invited:localhost" ) ;
554
+
555
+ let client = server
556
+ . client_builder ( )
557
+ . logged_in_with_token ( "ABCD" . to_owned ( ) , alice_user_id. into ( ) , "DEVICEID" . into ( ) )
558
+ . build ( )
559
+ . await ;
560
+
561
+ let event_factory = EventFactory :: new ( ) . room ( invited_rom_id) ;
562
+ let bob_member_event = event_factory. member ( bob_user_id) . into_raw_timeline ( ) ;
563
+ let alice_member_event =
564
+ event_factory. member ( bob_user_id) . invited ( alice_user_id) . into_raw_timeline ( ) ;
565
+
566
+ server
567
+ . mock_sync ( )
568
+ . ok_and_run ( & client, |builder| {
569
+ builder. add_joined_room ( JoinedRoomBuilder :: new ( joined_room_id) ) . add_invited_room (
570
+ InvitedRoomBuilder :: new ( invited_rom_id)
571
+ . add_state_event ( bob_member_event. cast ( ) )
572
+ . add_state_event ( alice_member_event. cast ( ) ) ,
573
+ ) ;
574
+ } )
575
+ . await ;
576
+
577
+ let room =
578
+ client. get_room ( joined_room_id) . expect ( "We should have access to our joined room now" ) ;
579
+
580
+ assert ! (
581
+ room. invite_acceptance_details( ) . is_none( ) ,
582
+ "We shouldn't have any invite acceptance details if we didn't join the room on this Client"
583
+ ) ;
584
+
585
+ let bundle_info = RoomKeyBundleInfo {
586
+ sender : bob_user_id. to_owned ( ) ,
587
+ room_id : joined_room_id. to_owned ( ) ,
588
+ } ;
589
+
590
+ assert ! (
591
+ !BundleReceiverTask :: should_accept_bundle( & room, & bundle_info) ,
592
+ "We should not acceept a bundle if we did not join the room from this Client"
593
+ ) ;
594
+
595
+ let invited_room =
596
+ client. get_room ( invited_rom_id) . expect ( "We should have access to our invited room now" ) ;
597
+
598
+ assert ! (
599
+ !BundleReceiverTask :: should_accept_bundle( & invited_room, & bundle_info) ,
600
+ "We should not accept a bundle if we didn't join the room."
601
+ ) ;
602
+
603
+ server. mock_room_join ( invited_rom_id) . ok ( ) . mock_once ( ) . mount ( ) . await ;
604
+
605
+ let room = client
606
+ . join_room_by_id ( invited_rom_id)
607
+ . await
608
+ . expect ( "We should be able to join the invited room" ) ;
609
+
610
+ let details = room
611
+ . invite_acceptance_details ( )
612
+ . expect ( "We should have stored the invite acceptance details" ) ;
613
+ assert_eq ! ( details. inviter, bob_user_id, "We should have recorded that Bob has invited us" ) ;
614
+
615
+ assert ! (
616
+ BundleReceiverTask :: should_accept_bundle( & room, & bundle_info) ,
617
+ "We should accept a bundle if we just joined the room and did so from this very Client object"
618
+ ) ;
619
+ }
454
620
}
0 commit comments