1- // Copyright 2023 The Matrix.org Foundation C.I.C.
1+ // Copyright 2023-2024 The Matrix.org Foundation C.I.C.
22//
33// Licensed under the Apache License, Version 2.0 (the "License");
44// you may not use this file except in compliance with the License.
1212// See the License for the specific language governing permissions and
1313// limitations under the License.
1414
15- use std:: { collections:: BTreeMap , time:: Duration } ;
15+ use std:: {
16+ collections:: { BTreeMap , HashSet } ,
17+ sync:: Arc ,
18+ time:: Duration ,
19+ } ;
1620
17- use futures_util:: future:: join_all;
1821use matrix_sdk_common:: failures_cache:: FailuresCache ;
1922use ruma:: {
2023 events:: room:: encrypted:: { EncryptedEventScheme , OriginalSyncRoomEncryptedEvent } ,
2124 serde:: Raw ,
2225 OwnedEventId , OwnedRoomId ,
2326} ;
24- use tokio:: sync:: mpsc:: { self , UnboundedReceiver } ;
25- use tracing:: { trace, warn} ;
27+ use tokio:: sync:: {
28+ mpsc:: { self , UnboundedReceiver } ,
29+ Mutex ,
30+ } ;
31+ use tracing:: { debug, trace, warn} ;
2632
2733use crate :: {
2834 client:: WeakClient ,
@@ -114,7 +120,6 @@ impl RoomKeyDownloadRequest {
114120}
115121
116122pub type RoomKeyInfo = ( OwnedRoomId , String ) ;
117- pub type TaskQueue = BTreeMap < RoomKeyInfo , JoinHandle < ( ) > > ;
118123
119124pub ( crate ) struct BackupDownloadTask {
120125 sender : mpsc:: UnboundedSender < RoomKeyDownloadRequest > ,
@@ -137,12 +142,7 @@ impl BackupDownloadTask {
137142 let ( sender, receiver) = mpsc:: unbounded_channel ( ) ;
138143
139144 let join_handle = spawn ( async move {
140- Self :: listen (
141- client,
142- receiver,
143- FailuresCache :: with_settings ( Duration :: from_secs ( 60 * 60 * 24 ) , 60 ) ,
144- )
145- . await ;
145+ Self :: listen ( client, receiver) . await ;
146146 } ) ;
147147
148148 Self { sender, join_handle }
@@ -169,70 +169,199 @@ impl BackupDownloadTask {
169169 }
170170 }
171171
172- pub ( crate ) async fn download (
173- client : Client ,
174- room_key_info : RoomKeyInfo ,
175- failures_cache : FailuresCache < RoomKeyInfo > ,
176- ) {
177- // Wait a bit, perhaps the room key will arrive in the meantime.
178- tokio:: time:: sleep ( Duration :: from_millis ( Self :: DOWNLOAD_DELAY_MILLIS ) ) . await ;
172+ /// Listen for incoming [`RoomKeyDownloadRequest`]s and process them.
173+ ///
174+ /// This will keep running until either the request channel is closed, or
175+ /// all other references to `Client` are dropped.
176+ ///
177+ /// # Arguments
178+ ///
179+ /// * `receiver` - The source of incoming [`RoomKeyDownloadRequest`]s.
180+ async fn listen ( client : WeakClient , mut receiver : UnboundedReceiver < RoomKeyDownloadRequest > ) {
181+ let state = Arc :: new ( Mutex :: new ( BackupDownloadTaskListenerState :: new ( client) ) ) ;
179182
180- if let Some ( machine ) = client . olm_machine ( ) . await . as_ref ( ) {
181- let ( room_id , session_id ) = & room_key_info ;
183+ while let Some ( room_key_download_request ) = receiver . recv ( ) . await {
184+ let mut state_guard = state . lock ( ) . await ;
182185
183- if !machine. is_room_key_available ( room_id, session_id) . await . unwrap ( ) {
184- match client. encryption ( ) . backups ( ) . download_room_key ( room_id, session_id) . await {
185- Ok ( _) => failures_cache. remove ( std:: iter:: once ( & room_key_info) ) ,
186- Err ( _) => failures_cache. insert ( room_key_info) ,
187- }
186+ if state_guard. client . strong_count ( ) == 0 {
187+ trace ! ( "Client got dropped, shutting down the task" ) ;
188+ break ;
189+ }
190+
191+ // Check that we don't already have a task to process this event, and fire one
192+ // off else if not.
193+ let event_id = & room_key_download_request. event_id ;
194+ if !state_guard. active_tasks . contains_key ( event_id) {
195+ let event_id = event_id. to_owned ( ) ;
196+ let task =
197+ spawn ( Self :: handle_download_request ( state. clone ( ) , room_key_download_request) ) ;
198+ state_guard. active_tasks . insert ( event_id, task) ;
188199 }
189200 }
190201 }
191202
192- pub ( crate ) async fn prune_tasks ( task_queue : & mut TaskQueue ) {
193- let mut handles = Vec :: with_capacity ( task_queue. len ( ) ) ;
203+ /// Handle a request to download a session for a given event.
204+ ///
205+ /// Sleeps for a while to see if the key turns up; then checks if we still
206+ /// want to do a download, and does the download if so.
207+ async fn handle_download_request (
208+ state : Arc < Mutex < BackupDownloadTaskListenerState > > ,
209+ download_request : RoomKeyDownloadRequest ,
210+ ) {
211+ // Wait a bit, perhaps the room key will arrive in the meantime.
212+ tokio:: time:: sleep ( Duration :: from_millis ( Self :: DOWNLOAD_DELAY_MILLIS ) ) . await ;
213+
214+ // Now take the lock, and check that we still want to do a download. If we do,
215+ // keep hold of a strong reference to the `Client`.
216+ let client = {
217+ let mut state = state. lock ( ) . await ;
218+
219+ let Some ( client) = state. client . get ( ) else {
220+ // The client was dropped while we were sleeping. We should just bail out;
221+ // the main BackupDownloadTask loop will bail out too.
222+ return ;
223+ } ;
224+
225+ // Check that we still want to do a download.
226+ if !state. should_download ( & client, & download_request) . await {
227+ // We decided against doing a download. Mark the job done for this event before
228+ // dropping the lock.
229+ state. active_tasks . remove ( & download_request. event_id ) ;
230+ return ;
231+ }
194232
195- while let Some ( ( _, handle) ) = task_queue. pop_first ( ) {
196- handles. push ( handle) ;
197- }
233+ // Before we drop the lock, indicate to other tasks that may be considering this
234+ // session that we're going to go ahead and do a download.
235+ state. downloaded_sessions . insert ( download_request. to_room_key_info ( ) ) ;
236+ client
237+ } ;
238+
239+ // Do the download without holding the lock.
240+ let result = client
241+ . encryption ( )
242+ . backups ( )
243+ . download_room_key ( & download_request. room_id , & download_request. megolm_session_id )
244+ . await ;
198245
199- join_all ( handles) . await ;
246+ // Then take the lock again to update the state.
247+ {
248+ let mut state = state. lock ( ) . await ;
249+ let room_key_info = download_request. to_room_key_info ( ) ;
250+ match result {
251+ Ok ( _) => {
252+ // We successfully downloaded the session. We can clear any record of previous
253+ // backoffs from the failures cache, because we won't be needing them again.
254+ state. failures_cache . remove ( std:: iter:: once ( & room_key_info) )
255+ }
256+ Err ( _) => {
257+ // We were unable to download the session. Update the failure cache so that we
258+ // back off from more requests, and also remove the entry from the list of
259+ // sessions that we are downloading.
260+ state. downloaded_sessions . remove ( & room_key_info) ;
261+ state. failures_cache . insert ( room_key_info) ;
262+ }
263+ }
264+ state. active_tasks . remove ( & download_request. event_id ) ;
265+ }
200266 }
267+ }
201268
202- pub ( crate ) async fn listen (
203- client : WeakClient ,
204- mut receiver : UnboundedReceiver < RoomKeyDownloadRequest > ,
205- failures_cache : FailuresCache < RoomKeyInfo > ,
206- ) {
207- let mut task_queue = TaskQueue :: new ( ) ;
269+ /// The state for an active [`BackupDownloadTask`].
270+ struct BackupDownloadTaskListenerState {
271+ /// Reference to the `Client`, which will be used to fire off the download
272+ /// requests.
273+ client : WeakClient ,
208274
209- while let Some ( room_key_download_request) = receiver. recv ( ) . await {
210- let room_key_info = room_key_download_request. to_room_key_info ( ) ;
211- trace ! ( ?room_key_info, "Got a request to download a room key from the backup" ) ;
275+ /// A record of backup download attempts that have recently failed.
276+ failures_cache : FailuresCache < RoomKeyInfo > ,
212277
213- if task_queue. len ( ) >= 10 {
214- Self :: prune_tasks ( & mut task_queue) . await
215- }
278+ /// Map from event ID to download task
279+ active_tasks : BTreeMap < OwnedEventId , JoinHandle < ( ) > > ,
216280
217- if let Some ( client) = client. get ( ) {
218- let backups = client. encryption ( ) . backups ( ) ;
281+ /// A list of megolm sessions that we have already downloaded, or are about
282+ /// to download.
283+ ///
284+ /// The idea here is that once we've (successfully) downloaded a session
285+ /// from the backup, there's not much point trying again even if we get
286+ /// another UTD event that uses the same session.
287+ ///
288+ /// TODO: that's not quite right though. In theory another client could
289+ /// update the backup with an earlier ratchet state, giving us access
290+ /// to earlier messages in the session. In which case, maybe this
291+ /// should expire?
292+ downloaded_sessions : HashSet < RoomKeyInfo > ,
293+ }
219294
220- let already_tried = failures_cache. contains ( & room_key_info) ;
221- let task_exists = task_queue. contains_key ( & room_key_info) ;
295+ impl BackupDownloadTaskListenerState {
296+ /// Prepare a new `BackupDownloadTaskListenerState`.
297+ ///
298+ /// # Arguments
299+ ///
300+ /// * `client` - A reference to the `Client`, which is used to fire off the
301+ /// backup download request.
302+ pub fn new ( client : WeakClient ) -> Self {
303+ Self {
304+ client,
305+ failures_cache : FailuresCache :: with_settings ( Duration :: from_secs ( 60 * 60 * 24 ) , 60 ) ,
306+ active_tasks : Default :: default ( ) ,
307+ downloaded_sessions : Default :: default ( ) ,
308+ }
309+ }
222310
223- if !already_tried && !task_exists && backups. are_enabled ( ) . await {
224- let task = spawn ( Self :: download (
225- client,
226- room_key_info. to_owned ( ) ,
227- failures_cache. to_owned ( ) ,
228- ) ) ;
311+ /// Check if we should set off a download for the given request.
312+ ///
313+ /// Checks if:
314+ /// * we already have the key,
315+ /// * we have already downloaded this session, or are about to do so, or
316+ /// * we've backed off from trying to download this session.
317+ ///
318+ /// If any of the above are true, returns `false`. Otherwise, returns
319+ /// `true`.
320+ pub async fn should_download (
321+ & self ,
322+ client : & Client ,
323+ download_request : & RoomKeyDownloadRequest ,
324+ ) -> bool {
325+ // Check that the Client has an OlmMachine
326+ let machine_guard = client. olm_machine ( ) . await ;
327+ let Some ( machine) = machine_guard. as_ref ( ) else {
328+ return false ;
329+ } ;
330+
331+ // Check if the keys for this message have arrived in the meantime.
332+ // If we get a StoreError doing the lookup, we assume the keys haven't arrived
333+ // (though if the store is returning errors, probably something else is
334+ // going to go wrong very soon).
335+ if machine
336+ . is_room_key_available ( & download_request. room_id , & download_request. megolm_session_id )
337+ . await
338+ . unwrap_or ( false )
339+ {
340+ debug ! ( ?download_request, "Not performing backup download because key became available while we were sleeping" ) ;
341+ return false ;
342+ }
229343
230- task_queue. insert ( room_key_info, task) ;
231- }
232- } else {
233- trace ! ( "Client got dropped, shutting down the task" ) ;
234- break ;
235- }
344+ // Check if we already downloaded this session, or another task is in the
345+ // process of doing so.
346+ let room_key_info = download_request. to_room_key_info ( ) ;
347+ if self . downloaded_sessions . contains ( & room_key_info) {
348+ debug ! (
349+ ?download_request,
350+ "Not performing backup download because this session has already been downloaded"
351+ ) ;
352+ return false ;
353+ } ;
354+
355+ // Check if we're backing off from attempts to download this session
356+ if self . failures_cache . contains ( & room_key_info) {
357+ debug ! (
358+ ?download_request,
359+ "Not performing backup download because this session failed to download recently"
360+ ) ;
361+ return false ;
236362 }
363+
364+ debug ! ( ?download_request, "Performing backup download" ) ;
365+ true
237366 }
238367}
0 commit comments