Skip to content

Commit 565f974

Browse files
committed
sdk: Rewrite BackupDownloadTask
Change this so that it fires off a task for each UTD event, rather than each megolm session. This is a step towards considering the message index when deciding whether to carry on with the download.
1 parent 792402e commit 565f974

File tree

1 file changed

+189
-60
lines changed
  • crates/matrix-sdk/src/encryption

1 file changed

+189
-60
lines changed

crates/matrix-sdk/src/encryption/tasks.rs

Lines changed: 189 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright 2023 The Matrix.org Foundation C.I.C.
1+
// Copyright 2023-2024 The Matrix.org Foundation C.I.C.
22
//
33
// Licensed under the Apache License, Version 2.0 (the "License");
44
// you may not use this file except in compliance with the License.
@@ -12,17 +12,23 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15-
use std::{collections::BTreeMap, time::Duration};
15+
use std::{
16+
collections::{BTreeMap, HashSet},
17+
sync::Arc,
18+
time::Duration,
19+
};
1620

17-
use futures_util::future::join_all;
1821
use matrix_sdk_common::failures_cache::FailuresCache;
1922
use ruma::{
2023
events::room::encrypted::{EncryptedEventScheme, OriginalSyncRoomEncryptedEvent},
2124
serde::Raw,
2225
OwnedEventId, OwnedRoomId,
2326
};
24-
use tokio::sync::mpsc::{self, UnboundedReceiver};
25-
use tracing::{trace, warn};
27+
use tokio::sync::{
28+
mpsc::{self, UnboundedReceiver},
29+
Mutex,
30+
};
31+
use tracing::{debug, trace, warn};
2632

2733
use crate::{
2834
client::WeakClient,
@@ -114,7 +120,6 @@ impl RoomKeyDownloadRequest {
114120
}
115121

116122
pub type RoomKeyInfo = (OwnedRoomId, String);
117-
pub type TaskQueue = BTreeMap<RoomKeyInfo, JoinHandle<()>>;
118123

119124
pub(crate) struct BackupDownloadTask {
120125
sender: mpsc::UnboundedSender<RoomKeyDownloadRequest>,
@@ -137,12 +142,7 @@ impl BackupDownloadTask {
137142
let (sender, receiver) = mpsc::unbounded_channel();
138143

139144
let join_handle = spawn(async move {
140-
Self::listen(
141-
client,
142-
receiver,
143-
FailuresCache::with_settings(Duration::from_secs(60 * 60 * 24), 60),
144-
)
145-
.await;
145+
Self::listen(client, receiver).await;
146146
});
147147

148148
Self { sender, join_handle }
@@ -169,70 +169,199 @@ impl BackupDownloadTask {
169169
}
170170
}
171171

172-
pub(crate) async fn download(
173-
client: Client,
174-
room_key_info: RoomKeyInfo,
175-
failures_cache: FailuresCache<RoomKeyInfo>,
176-
) {
177-
// Wait a bit, perhaps the room key will arrive in the meantime.
178-
tokio::time::sleep(Duration::from_millis(Self::DOWNLOAD_DELAY_MILLIS)).await;
172+
/// Listen for incoming [`RoomKeyDownloadRequest`]s and process them.
173+
///
174+
/// This will keep running until either the request channel is closed, or
175+
/// all other references to `Client` are dropped.
176+
///
177+
/// # Arguments
178+
///
179+
/// * `receiver` - The source of incoming [`RoomKeyDownloadRequest`]s.
180+
async fn listen(client: WeakClient, mut receiver: UnboundedReceiver<RoomKeyDownloadRequest>) {
181+
let state = Arc::new(Mutex::new(BackupDownloadTaskListenerState::new(client)));
179182

180-
if let Some(machine) = client.olm_machine().await.as_ref() {
181-
let (room_id, session_id) = &room_key_info;
183+
while let Some(room_key_download_request) = receiver.recv().await {
184+
let mut state_guard = state.lock().await;
182185

183-
if !machine.is_room_key_available(room_id, session_id).await.unwrap() {
184-
match client.encryption().backups().download_room_key(room_id, session_id).await {
185-
Ok(_) => failures_cache.remove(std::iter::once(&room_key_info)),
186-
Err(_) => failures_cache.insert(room_key_info),
187-
}
186+
if state_guard.client.strong_count() == 0 {
187+
trace!("Client got dropped, shutting down the task");
188+
break;
189+
}
190+
191+
// Check that we don't already have a task to process this event, and fire one
192+
// off else if not.
193+
let event_id = &room_key_download_request.event_id;
194+
if !state_guard.active_tasks.contains_key(event_id) {
195+
let event_id = event_id.to_owned();
196+
let task =
197+
spawn(Self::handle_download_request(state.clone(), room_key_download_request));
198+
state_guard.active_tasks.insert(event_id, task);
188199
}
189200
}
190201
}
191202

192-
pub(crate) async fn prune_tasks(task_queue: &mut TaskQueue) {
193-
let mut handles = Vec::with_capacity(task_queue.len());
203+
/// Handle a request to download a session for a given event.
204+
///
205+
/// Sleeps for a while to see if the key turns up; then checks if we still
206+
/// want to do a download, and does the download if so.
207+
async fn handle_download_request(
208+
state: Arc<Mutex<BackupDownloadTaskListenerState>>,
209+
download_request: RoomKeyDownloadRequest,
210+
) {
211+
// Wait a bit, perhaps the room key will arrive in the meantime.
212+
tokio::time::sleep(Duration::from_millis(Self::DOWNLOAD_DELAY_MILLIS)).await;
213+
214+
// Now take the lock, and check that we still want to do a download. If we do,
215+
// keep hold of a strong reference to the `Client`.
216+
let client = {
217+
let mut state = state.lock().await;
218+
219+
let Some(client) = state.client.get() else {
220+
// The client was dropped while we were sleeping. We should just bail out;
221+
// the main BackupDownloadTask loop will bail out too.
222+
return;
223+
};
224+
225+
// Check that we still want to do a download.
226+
if !state.should_download(&client, &download_request).await {
227+
// We decided against doing a download. Mark the job done for this event before
228+
// dropping the lock.
229+
state.active_tasks.remove(&download_request.event_id);
230+
return;
231+
}
194232

195-
while let Some((_, handle)) = task_queue.pop_first() {
196-
handles.push(handle);
197-
}
233+
// Before we drop the lock, indicate to other tasks that may be considering this
234+
// session that we're going to go ahead and do a download.
235+
state.downloaded_sessions.insert(download_request.to_room_key_info());
236+
client
237+
};
238+
239+
// Do the download without holding the lock.
240+
let result = client
241+
.encryption()
242+
.backups()
243+
.download_room_key(&download_request.room_id, &download_request.megolm_session_id)
244+
.await;
198245

199-
join_all(handles).await;
246+
// Then take the lock again to update the state.
247+
{
248+
let mut state = state.lock().await;
249+
let room_key_info = download_request.to_room_key_info();
250+
match result {
251+
Ok(_) => {
252+
// We successfully downloaded the session. We can clear any record of previous
253+
// backoffs from the failures cache, because we won't be needing them again.
254+
state.failures_cache.remove(std::iter::once(&room_key_info))
255+
}
256+
Err(_) => {
257+
// We were unable to download the session. Update the failure cache so that we
258+
// back off from more requests, and also remove the entry from the list of
259+
// sessions that we are downloading.
260+
state.downloaded_sessions.remove(&room_key_info);
261+
state.failures_cache.insert(room_key_info);
262+
}
263+
}
264+
state.active_tasks.remove(&download_request.event_id);
265+
}
200266
}
267+
}
201268

202-
pub(crate) async fn listen(
203-
client: WeakClient,
204-
mut receiver: UnboundedReceiver<RoomKeyDownloadRequest>,
205-
failures_cache: FailuresCache<RoomKeyInfo>,
206-
) {
207-
let mut task_queue = TaskQueue::new();
269+
/// The state for an active [`BackupDownloadTask`].
270+
struct BackupDownloadTaskListenerState {
271+
/// Reference to the `Client`, which will be used to fire off the download
272+
/// requests.
273+
client: WeakClient,
208274

209-
while let Some(room_key_download_request) = receiver.recv().await {
210-
let room_key_info = room_key_download_request.to_room_key_info();
211-
trace!(?room_key_info, "Got a request to download a room key from the backup");
275+
/// A record of backup download attempts that have recently failed.
276+
failures_cache: FailuresCache<RoomKeyInfo>,
212277

213-
if task_queue.len() >= 10 {
214-
Self::prune_tasks(&mut task_queue).await
215-
}
278+
/// Map from event ID to download task
279+
active_tasks: BTreeMap<OwnedEventId, JoinHandle<()>>,
216280

217-
if let Some(client) = client.get() {
218-
let backups = client.encryption().backups();
281+
/// A list of megolm sessions that we have already downloaded, or are about
282+
/// to download.
283+
///
284+
/// The idea here is that once we've (successfully) downloaded a session
285+
/// from the backup, there's not much point trying again even if we get
286+
/// another UTD event that uses the same session.
287+
///
288+
/// TODO: that's not quite right though. In theory another client could
289+
/// update the backup with an earlier ratchet state, giving us access
290+
/// to earlier messages in the session. In which case, maybe this
291+
/// should expire?
292+
downloaded_sessions: HashSet<RoomKeyInfo>,
293+
}
219294

220-
let already_tried = failures_cache.contains(&room_key_info);
221-
let task_exists = task_queue.contains_key(&room_key_info);
295+
impl BackupDownloadTaskListenerState {
296+
/// Prepare a new `BackupDownloadTaskListenerState`.
297+
///
298+
/// # Arguments
299+
///
300+
/// * `client` - A reference to the `Client`, which is used to fire off the
301+
/// backup download request.
302+
pub fn new(client: WeakClient) -> Self {
303+
Self {
304+
client,
305+
failures_cache: FailuresCache::with_settings(Duration::from_secs(60 * 60 * 24), 60),
306+
active_tasks: Default::default(),
307+
downloaded_sessions: Default::default(),
308+
}
309+
}
222310

223-
if !already_tried && !task_exists && backups.are_enabled().await {
224-
let task = spawn(Self::download(
225-
client,
226-
room_key_info.to_owned(),
227-
failures_cache.to_owned(),
228-
));
311+
/// Check if we should set off a download for the given request.
312+
///
313+
/// Checks if:
314+
/// * we already have the key,
315+
/// * we have already downloaded this session, or are about to do so, or
316+
/// * we've backed off from trying to download this session.
317+
///
318+
/// If any of the above are true, returns `false`. Otherwise, returns
319+
/// `true`.
320+
pub async fn should_download(
321+
&self,
322+
client: &Client,
323+
download_request: &RoomKeyDownloadRequest,
324+
) -> bool {
325+
// Check that the Client has an OlmMachine
326+
let machine_guard = client.olm_machine().await;
327+
let Some(machine) = machine_guard.as_ref() else {
328+
return false;
329+
};
330+
331+
// Check if the keys for this message have arrived in the meantime.
332+
// If we get a StoreError doing the lookup, we assume the keys haven't arrived
333+
// (though if the store is returning errors, probably something else is
334+
// going to go wrong very soon).
335+
if machine
336+
.is_room_key_available(&download_request.room_id, &download_request.megolm_session_id)
337+
.await
338+
.unwrap_or(false)
339+
{
340+
debug!(?download_request, "Not performing backup download because key became available while we were sleeping");
341+
return false;
342+
}
229343

230-
task_queue.insert(room_key_info, task);
231-
}
232-
} else {
233-
trace!("Client got dropped, shutting down the task");
234-
break;
235-
}
344+
// Check if we already downloaded this session, or another task is in the
345+
// process of doing so.
346+
let room_key_info = download_request.to_room_key_info();
347+
if self.downloaded_sessions.contains(&room_key_info) {
348+
debug!(
349+
?download_request,
350+
"Not performing backup download because this session has already been downloaded"
351+
);
352+
return false;
353+
};
354+
355+
// Check if we're backing off from attempts to download this session
356+
if self.failures_cache.contains(&room_key_info) {
357+
debug!(
358+
?download_request,
359+
"Not performing backup download because this session failed to download recently"
360+
);
361+
return false;
236362
}
363+
364+
debug!(?download_request, "Performing backup download");
365+
true
237366
}
238367
}

0 commit comments

Comments
 (0)