diff --git a/Cargo.lock b/Cargo.lock index 01c4900caf5..1daae66a8b7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4639,7 +4639,7 @@ dependencies = [ [[package]] name = "ruma" version = "0.12.6" -source = "git+https://github.com/ruma/ruma?rev=57049282e3a74f67f86e4eb2382a3e649b57cc2b#57049282e3a74f67f86e4eb2382a3e649b57cc2b" +source = "git+https://github.com/ruma/ruma?rev=51fb51a560027fd330b43398a278922cacbc825c#51fb51a560027fd330b43398a278922cacbc825c" dependencies = [ "assign", "js_int", @@ -4656,7 +4656,7 @@ dependencies = [ [[package]] name = "ruma-client-api" version = "0.20.4" -source = "git+https://github.com/ruma/ruma?rev=57049282e3a74f67f86e4eb2382a3e649b57cc2b#57049282e3a74f67f86e4eb2382a3e649b57cc2b" +source = "git+https://github.com/ruma/ruma?rev=51fb51a560027fd330b43398a278922cacbc825c#51fb51a560027fd330b43398a278922cacbc825c" dependencies = [ "as_variant", "assign", @@ -4679,7 +4679,7 @@ dependencies = [ [[package]] name = "ruma-common" version = "0.15.4" -source = "git+https://github.com/ruma/ruma?rev=57049282e3a74f67f86e4eb2382a3e649b57cc2b#57049282e3a74f67f86e4eb2382a3e649b57cc2b" +source = "git+https://github.com/ruma/ruma?rev=51fb51a560027fd330b43398a278922cacbc825c#51fb51a560027fd330b43398a278922cacbc825c" dependencies = [ "as_variant", "base64", @@ -4712,7 +4712,7 @@ dependencies = [ [[package]] name = "ruma-events" version = "0.30.5" -source = "git+https://github.com/ruma/ruma?rev=57049282e3a74f67f86e4eb2382a3e649b57cc2b#57049282e3a74f67f86e4eb2382a3e649b57cc2b" +source = "git+https://github.com/ruma/ruma?rev=51fb51a560027fd330b43398a278922cacbc825c#51fb51a560027fd330b43398a278922cacbc825c" dependencies = [ "as_variant", "indexmap", @@ -4738,7 +4738,7 @@ dependencies = [ [[package]] name = "ruma-federation-api" version = "0.11.2" -source = "git+https://github.com/ruma/ruma?rev=57049282e3a74f67f86e4eb2382a3e649b57cc2b#57049282e3a74f67f86e4eb2382a3e649b57cc2b" +source = "git+https://github.com/ruma/ruma?rev=51fb51a560027fd330b43398a278922cacbc825c#51fb51a560027fd330b43398a278922cacbc825c" dependencies = [ "headers", "http", @@ -4758,7 +4758,7 @@ dependencies = [ [[package]] name = "ruma-html" version = "0.4.1" -source = "git+https://github.com/ruma/ruma?rev=57049282e3a74f67f86e4eb2382a3e649b57cc2b#57049282e3a74f67f86e4eb2382a3e649b57cc2b" +source = "git+https://github.com/ruma/ruma?rev=51fb51a560027fd330b43398a278922cacbc825c#51fb51a560027fd330b43398a278922cacbc825c" dependencies = [ "as_variant", "html5ever", @@ -4769,7 +4769,7 @@ dependencies = [ [[package]] name = "ruma-identifiers-validation" version = "0.10.1" -source = "git+https://github.com/ruma/ruma?rev=57049282e3a74f67f86e4eb2382a3e649b57cc2b#57049282e3a74f67f86e4eb2382a3e649b57cc2b" +source = "git+https://github.com/ruma/ruma?rev=51fb51a560027fd330b43398a278922cacbc825c#51fb51a560027fd330b43398a278922cacbc825c" dependencies = [ "js_int", "thiserror 2.0.11", @@ -4778,7 +4778,7 @@ dependencies = [ [[package]] name = "ruma-macros" version = "0.15.2" -source = "git+https://github.com/ruma/ruma?rev=57049282e3a74f67f86e4eb2382a3e649b57cc2b#57049282e3a74f67f86e4eb2382a3e649b57cc2b" +source = "git+https://github.com/ruma/ruma?rev=51fb51a560027fd330b43398a278922cacbc825c#51fb51a560027fd330b43398a278922cacbc825c" dependencies = [ "cfg-if", "proc-macro-crate", @@ -4793,7 +4793,7 @@ dependencies = [ [[package]] name = "ruma-signatures" version = "0.17.1" -source = "git+https://github.com/ruma/ruma?rev=57049282e3a74f67f86e4eb2382a3e649b57cc2b#57049282e3a74f67f86e4eb2382a3e649b57cc2b" +source = 
"git+https://github.com/ruma/ruma?rev=51fb51a560027fd330b43398a278922cacbc825c#51fb51a560027fd330b43398a278922cacbc825c" dependencies = [ "base64", "ed25519-dalek", diff --git a/Cargo.toml b/Cargo.toml index 7c3bb42fb93..18b29e1e7ec 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -59,7 +59,7 @@ proptest = { version = "1.6.0", default-features = false, features = ["std"] } rand = "0.8.5" reqwest = { version = "0.12.12", default-features = false } rmp-serde = "1.3.0" -ruma = { git = "https://github.com/ruma/ruma", rev = "57049282e3a74f67f86e4eb2382a3e649b57cc2b", features = [ +ruma = { git = "https://github.com/ruma/ruma", rev = "51fb51a560027fd330b43398a278922cacbc825c", features = [ "client-api-c", "compat-upload-signatures", "compat-arbitrary-length-ids", diff --git a/crates/matrix-sdk-base/src/lib.rs b/crates/matrix-sdk-base/src/lib.rs index 15cb0af9242..6d29168c25d 100644 --- a/crates/matrix-sdk-base/src/lib.rs +++ b/crates/matrix-sdk-base/src/lib.rs @@ -62,7 +62,7 @@ pub use room::{ }; pub use store::{ ComposerDraft, ComposerDraftType, QueueWedgeError, StateChanges, StateStore, StateStoreDataKey, - StateStoreDataValue, StoreError, + StateStoreDataValue, StoreError, ThreadSubscriptionCatchupToken, }; pub use utils::{ MinimalRoomMemberEvent, MinimalStateEvent, OriginalMinimalStateEvent, RedactedMinimalStateEvent, diff --git a/crates/matrix-sdk-base/src/store/memory_store.rs b/crates/matrix-sdk-base/src/store/memory_store.rs index f166475fc00..6ffedcf686a 100644 --- a/crates/matrix-sdk-base/src/store/memory_store.rs +++ b/crates/matrix-sdk-base/src/store/memory_store.rs @@ -46,7 +46,8 @@ use crate::{ MinimalRoomMemberEvent, RoomMemberships, StateStoreDataKey, StateStoreDataValue, deserialized_responses::{DisplayName, RawAnySyncOrStrippedState}, store::{ - QueueWedgeError, StoredThreadSubscription, traits::compare_thread_subscription_bump_stamps, + QueueWedgeError, StoredThreadSubscription, + traits::{ThreadSubscriptionCatchupToken, compare_thread_subscription_bump_stamps}, }, }; @@ -87,6 +88,7 @@ struct MemoryStoreInner { dependent_send_queue_events: BTreeMap>, seen_knock_requests: BTreeMap>, thread_subscriptions: BTreeMap>, + thread_subscriptions_catchup_tokens: Option>, } /// In-memory, non-persistent implementation of the `StateStore`. 
@@ -183,6 +185,10 @@ impl StateStore for MemoryStore {
                 .get(room_id)
                 .cloned()
                 .map(StateStoreDataValue::SeenKnockRequests),
+            StateStoreDataKey::ThreadSubscriptionsCatchupTokens => inner
+                .thread_subscriptions_catchup_tokens
+                .clone()
+                .map(StateStoreDataValue::ThreadSubscriptionsCatchupTokens),
         })
     }
@@ -246,6 +252,12 @@ impl StateStore for MemoryStore {
                         .expect("Session data is not a set of seen join request ids"),
                 );
             }
+            StateStoreDataKey::ThreadSubscriptionsCatchupTokens => {
+                inner.thread_subscriptions_catchup_tokens =
+                    Some(value.into_thread_subscriptions_catchup_tokens().expect(
+                        "Session data is not a list of thread subscription catchup tokens",
+                    ));
+            }
         }
         Ok(())
     }
@@ -276,6 +288,9 @@ impl StateStore for MemoryStore {
             StateStoreDataKey::SeenKnockRequests(room_id) => {
                 inner.seen_knock_requests.remove(room_id);
             }
+            StateStoreDataKey::ThreadSubscriptionsCatchupTokens => {
+                inner.thread_subscriptions_catchup_tokens = None;
+            }
         }
         Ok(())
     }
diff --git a/crates/matrix-sdk-base/src/store/mod.rs b/crates/matrix-sdk-base/src/store/mod.rs
index a28625c6715..1e5da059da5 100644
--- a/crates/matrix-sdk-base/src/store/mod.rs
+++ b/crates/matrix-sdk-base/src/store/mod.rs
@@ -93,7 +93,8 @@ pub use self::{
     },
     traits::{
         ComposerDraft, ComposerDraftType, DynStateStore, IntoStateStore, ServerInfo, StateStore,
-        StateStoreDataKey, StateStoreDataValue, StateStoreExt, WellKnownResponse,
+        StateStoreDataKey, StateStoreDataValue, StateStoreExt, ThreadSubscriptionCatchupToken,
+        WellKnownResponse,
     },
 };
diff --git a/crates/matrix-sdk-base/src/store/traits.rs b/crates/matrix-sdk-base/src/store/traits.rs
index 783dc2dabeb..fe0eed45d20 100644
--- a/crates/matrix-sdk-base/src/store/traits.rs
+++ b/crates/matrix-sdk-base/src/store/traits.rs
@@ -1151,6 +1151,38 @@ pub enum StateStoreDataValue {
     /// A list of knock request ids marked as seen in a room.
     SeenKnockRequests(BTreeMap),
+
+    /// A list of tokens to continue thread subscriptions catchup.
+    ///
+    /// See the documentation of [`ThreadSubscriptionCatchupToken`] for more
+    /// details.
+    ThreadSubscriptionsCatchupTokens(Vec<ThreadSubscriptionCatchupToken>),
+}
+
+/// Tokens to use when catching up on thread subscriptions.
+///
+/// These tokens are created when the client receives some thread subscriptions
+/// from sync, but the sync indicates that there are more thread subscriptions
+/// available on the server. In this case, it's expected that the client will
+/// call the [MSC4308] companion endpoint to catch up (back-paginate) on
+/// previous thread subscriptions.
+///
+/// [MSC4308]: https://github.com/matrix-org/matrix-spec-proposals/pull/4308
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
+pub struct ThreadSubscriptionCatchupToken {
+    /// The token to use as the lower bound when fetching new thread
+    /// subscriptions.
+    ///
+    /// In sliding sync, this is the `prev_batch` value of a sliding sync
+    /// response.
+    pub from: String,
+
+    /// The token to use as the upper bound when fetching new thread
+    /// subscriptions.
+    ///
+    /// In sliding sync, it must be set to the `pos` value of the sliding sync
+    /// *request* whose response contained the `prev_batch` token.
+    pub to: Option<String>,
 }
 
 /// Current draft of the composer for the room.
@@ -1222,6 +1254,14 @@ impl StateStoreDataValue {
     pub fn into_seen_knock_requests(self) -> Option> {
         as_variant!(self, Self::SeenKnockRequests)
     }
+
+    /// Get this value if it is the data for the thread subscriptions catchup
+    /// tokens.
+ pub fn into_thread_subscriptions_catchup_tokens( + self, + ) -> Option> { + as_variant!(self, Self::ThreadSubscriptionsCatchupTokens) + } } /// A key for key-value data. @@ -1258,6 +1298,9 @@ pub enum StateStoreDataKey<'a> { /// A list of knock request ids marked as seen in a room. SeenKnockRequests(&'a RoomId), + + /// A list of thread subscriptions catchup tokens. + ThreadSubscriptionsCatchupTokens, } impl StateStoreDataKey<'_> { @@ -1294,6 +1337,11 @@ impl StateStoreDataKey<'_> { /// Key prefix to use for the /// [`SeenKnockRequests`][Self::SeenKnockRequests] variant. pub const SEEN_KNOCK_REQUESTS: &'static str = "seen_knock_requests"; + + /// Key prefix to use for the + /// [`ThreadSubscriptionsCatchupTokens`][Self::ThreadSubscriptionsCatchupTokens] variant. + pub const THREAD_SUBSCRIPTIONS_CATCHUP_TOKENS: &'static str = + "thread_subscriptions_catchup_tokens"; } /// Compare two thread subscription changes bump stamps, given a fixed room and diff --git a/crates/matrix-sdk-indexeddb/src/state_store/mod.rs b/crates/matrix-sdk-indexeddb/src/state_store/mod.rs index 4d43d0ece58..82728bf17d2 100644 --- a/crates/matrix-sdk-indexeddb/src/state_store/mod.rs +++ b/crates/matrix-sdk-indexeddb/src/state_store/mod.rs @@ -32,7 +32,7 @@ use matrix_sdk_base::{ StateStore, StoreError, StoredThreadSubscription, ThreadSubscriptionStatus, }, MinimalRoomMemberEvent, RoomInfo, RoomMemberships, StateStoreDataKey, StateStoreDataValue, - ROOM_VERSION_FALLBACK, ROOM_VERSION_RULES_FALLBACK, + ThreadSubscriptionCatchupToken, ROOM_VERSION_FALLBACK, ROOM_VERSION_RULES_FALLBACK, }; use matrix_sdk_store_encryption::{Error as EncryptionError, StoreCipher}; use ruma::{ @@ -439,6 +439,9 @@ impl IndexeddbStateStore { StateStoreDataKey::SeenKnockRequests(room_id) => { self.encode_key(keys::KV, (StateStoreDataKey::SEEN_KNOCK_REQUESTS, room_id)) } + StateStoreDataKey::ThreadSubscriptionsCatchupTokens => { + self.encode_key(keys::KV, StateStoreDataKey::THREAD_SUBSCRIPTIONS_CATCHUP_TOKENS) + } } } } @@ -591,6 +594,10 @@ impl_state_store!({ .map(|f| self.deserialize_value::>(&f)) .transpose()? .map(StateStoreDataValue::SeenKnockRequests), + StateStoreDataKey::ThreadSubscriptionsCatchupTokens => value + .map(|f| self.deserialize_value::>(&f)) + .transpose()? 
+ .map(StateStoreDataValue::ThreadSubscriptionsCatchupTokens), }; Ok(value) @@ -632,6 +639,11 @@ impl_state_store!({ .into_seen_knock_requests() .expect("Session data is not a set of seen knock request ids"), ), + StateStoreDataKey::ThreadSubscriptionsCatchupTokens => self.serialize_value( + &value + .into_thread_subscriptions_catchup_tokens() + .expect("Session data is not a list of thread subscription catchup tokens"), + ), }; let tx = diff --git a/crates/matrix-sdk-sqlite/src/state_store.rs b/crates/matrix-sdk-sqlite/src/state_store.rs index 1633c9aa442..c4c3ed3d71b 100644 --- a/crates/matrix-sdk-sqlite/src/state_store.rs +++ b/crates/matrix-sdk-sqlite/src/state_store.rs @@ -40,7 +40,7 @@ use ruma::{ use rusqlite::{OptionalExtension, Transaction}; use serde::{Deserialize, Serialize}; use tokio::fs; -use tracing::{debug, warn}; +use tracing::{debug, trace, warn}; use crate::{ error::{Error, Result}, @@ -416,6 +416,9 @@ impl SqliteStateStore { StateStoreDataKey::SeenKnockRequests(room_id) => { Cow::Owned(format!("{}:{room_id}", StateStoreDataKey::SEEN_KNOCK_REQUESTS)) } + StateStoreDataKey::ThreadSubscriptionsCatchupTokens => { + Cow::Borrowed(StateStoreDataKey::THREAD_SUBSCRIPTIONS_CATCHUP_TOKENS) + } }; self.encode_key(keys::KV_BLOB, &*key_s) @@ -1037,6 +1040,11 @@ impl StateStore for SqliteStateStore { StateStoreDataKey::SeenKnockRequests(_) => { StateStoreDataValue::SeenKnockRequests(self.deserialize_value(&data)?) } + StateStoreDataKey::ThreadSubscriptionsCatchupTokens => { + StateStoreDataValue::ThreadSubscriptionsCatchupTokens( + self.deserialize_value(&data)?, + ) + } }) }) .transpose() @@ -1077,6 +1085,11 @@ impl StateStore for SqliteStateStore { .into_seen_knock_requests() .expect("Session data is not a set of seen knock request ids"), )?, + StateStoreDataKey::ThreadSubscriptionsCatchupTokens => self.serialize_value( + &value + .into_thread_subscriptions_catchup_tokens() + .expect("Session data is not a list of thread subscription catchup tokens"), + )?, }; self.acquire() @@ -2125,9 +2138,11 @@ impl StateStore for SqliteStateStore { if let Some(previous) = self.load_thread_subscription(room_id, thread_id).await? { if previous == new { // No need to update anything. 
+ trace!("not saving thread subscription because the subscription is the same"); return Ok(()); } if !compare_thread_subscription_bump_stamps(previous.bump_stamp, &mut new.bump_stamp) { + trace!("not saving thread subscription because we have a newer bump stamp"); return Ok(()); } } diff --git a/crates/matrix-sdk-ui/src/room_list_service/mod.rs b/crates/matrix-sdk-ui/src/room_list_service/mod.rs index 38a2a6c62e9..6fc9d3a634c 100644 --- a/crates/matrix-sdk-ui/src/room_list_service/mod.rs +++ b/crates/matrix-sdk-ui/src/room_list_service/mod.rs @@ -155,6 +155,15 @@ impl RoomListService { enabled: Some(true), })); + if client.enabled_thread_subscriptions() { + builder = builder.with_thread_subscriptions_extension( + assign!(http::request::ThreadSubscriptions::default(), { + enabled: Some(true), + limit: Some(ruma::uint!(10)) + }), + ); + } + if share_pos { // We don't deal with encryption device messages here so this is safe builder = builder.share_pos(); diff --git a/crates/matrix-sdk/src/client/builder/mod.rs b/crates/matrix-sdk/src/client/builder/mod.rs index af4b8d23d28..abeb36c150b 100644 --- a/crates/matrix-sdk/src/client/builder/mod.rs +++ b/crates/matrix-sdk/src/client/builder/mod.rs @@ -608,6 +608,7 @@ impl ClientBuilder { let event_cache = OnceCell::new(); let latest_events = OnceCell::new(); + let thread_subscriptions_catchup = OnceCell::new(); #[cfg(feature = "experimental-search")] let search_index = @@ -632,6 +633,7 @@ impl ClientBuilder { self.cross_process_store_locks_holder_name, #[cfg(feature = "experimental-search")] search_index, + thread_subscriptions_catchup, ) .await; diff --git a/crates/matrix-sdk/src/client/mod.rs b/crates/matrix-sdk/src/client/mod.rs index 2a5667af8bb..e7c82f821a6 100644 --- a/crates/matrix-sdk/src/client/mod.rs +++ b/crates/matrix-sdk/src/client/mod.rs @@ -86,6 +86,7 @@ use crate::{ matrix::MatrixAuth, oauth::OAuth, AuthCtx, AuthData, ReloadSessionCallback, SaveSessionCallback, }, + client::thread_subscriptions::ThreadSubscriptionCatchup, config::{RequestConfig, SyncToken}, deduplicating_handler::DeduplicatingHandler, error::HttpResult, @@ -117,6 +118,7 @@ pub(crate) mod caches; pub(crate) mod futures; #[cfg(feature = "experimental-search")] pub(crate) mod search; +pub(crate) mod thread_subscriptions; pub use self::builder::{sanitize_server_name, ClientBuildError, ClientBuilder}; #[cfg(feature = "experimental-search")] @@ -369,6 +371,10 @@ pub(crate) struct ClientInner { /// [`LatestEvent`]: crate::latest_event::LatestEvent latest_events: OnceCell, + /// Service handling the catching up of thread subscriptions in the + /// background. 
+ thread_subscription_catchup: OnceCell>, + #[cfg(feature = "experimental-search")] /// Handler for [`RoomIndex`]'s of each room search_index: SearchIndex, @@ -397,6 +403,7 @@ impl ClientInner { #[cfg(feature = "e2e-encryption")] enable_share_history_on_invite: bool, cross_process_store_locks_holder_name: String, #[cfg(feature = "experimental-search")] search_index_handler: SearchIndex, + thread_subscription_catchup: OnceCell>, ) -> Arc { let caches = ClientCaches { server_info: server_info.into(), @@ -434,6 +441,7 @@ impl ClientInner { server_max_upload_size: Mutex::new(OnceCell::new()), #[cfg(feature = "experimental-search")] search_index: search_index_handler, + thread_subscription_catchup, }; #[allow(clippy::let_and_return)] @@ -452,6 +460,13 @@ impl ClientInner { }) .await; + let _ = client + .thread_subscription_catchup + .get_or_init(|| async { + ThreadSubscriptionCatchup::new(Client { inner: client.clone() }) + }) + .await; + client } } @@ -2770,6 +2785,7 @@ impl Client { cross_process_store_locks_holder_name, #[cfg(feature = "experimental-search")] self.inner.search_index.clone(), + self.inner.thread_subscription_catchup.clone(), ) .await, }; @@ -2913,6 +2929,10 @@ impl Client { }); Ok(self.send(request).await?) } + + pub(crate) fn thread_subscription_catchup(&self) -> &ThreadSubscriptionCatchup { + self.inner.thread_subscription_catchup.get().unwrap() + } } #[cfg(any(feature = "testing", test))] diff --git a/crates/matrix-sdk/src/client/thread_subscriptions.rs b/crates/matrix-sdk/src/client/thread_subscriptions.rs new file mode 100644 index 00000000000..2f5b927087c --- /dev/null +++ b/crates/matrix-sdk/src/client/thread_subscriptions.rs @@ -0,0 +1,441 @@ +// Copyright 2025 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::{ + collections::BTreeMap, + sync::{ + atomic::{self, AtomicBool}, + Arc, + }, +}; + +use matrix_sdk_base::{ + executor::AbortOnDrop, + store::{StoredThreadSubscription, ThreadSubscriptionStatus}, + StateStoreDataKey, StateStoreDataValue, ThreadSubscriptionCatchupToken, +}; +use matrix_sdk_common::executor::spawn; +use once_cell::sync::OnceCell; +use ruma::{ + api::client::threads::get_thread_subscriptions_changes::unstable::{ + ThreadSubscription, ThreadUnsubscription, + }, + assign, OwnedEventId, OwnedRoomId, +}; +use tokio::sync::{ + mpsc::{channel, Receiver, Sender}, + Mutex, OwnedMutexGuard, +}; +use tracing::{instrument, trace, warn}; + +use crate::{client::WeakClient, Client, Result}; + +struct GuardedStoreAccess { + _mutex: OwnedMutexGuard<()>, + client: Client, + is_outdated: Arc, +} + +impl GuardedStoreAccess { + /// Return the current list of catchup tokens, if any. + /// + /// It is guaranteed that if the list is set, then it's non-empty. 
+    async fn load_catchup_tokens(&self) -> Result<Option<Vec<ThreadSubscriptionCatchupToken>>> {
+        let loaded = self
+            .client
+            .state_store()
+            .get_kv_data(StateStoreDataKey::ThreadSubscriptionsCatchupTokens)
+            .await?;
+
+        match loaded {
+            Some(data) => {
+                if let Some(tokens) = data.into_thread_subscriptions_catchup_tokens() {
+                    // If the tokens list is empty, automatically clean it up.
+                    if tokens.is_empty() {
+                        self.save_catchup_tokens(tokens).await?;
+                        Ok(None)
+                    } else {
+                        Ok(Some(tokens))
+                    }
+                } else {
+                    warn!(
+                        "invalid data in thread subscriptions catchup tokens state store k/v entry"
+                    );
+                    Ok(None)
+                }
+            }
+
+            None => Ok(None),
+        }
+    }
+
+    /// Saves the tokens in the database.
+    ///
+    /// Returns whether the list of tokens is empty or not.
+    #[instrument(skip_all, fields(num_tokens = tokens.len()))]
+    async fn save_catchup_tokens(
+        &self,
+        tokens: Vec<ThreadSubscriptionCatchupToken>,
+    ) -> Result<bool> {
+        let store = self.client.state_store();
+        let is_empty = if tokens.is_empty() {
+            store.remove_kv_data(StateStoreDataKey::ThreadSubscriptionsCatchupTokens).await?;
+
+            trace!("Marking thread subscriptions as not outdated \\o/");
+            self.is_outdated.store(false, atomic::Ordering::SeqCst);
+            true
+        } else {
+            store
+                .set_kv_data(
+                    StateStoreDataKey::ThreadSubscriptionsCatchupTokens,
+                    StateStoreDataValue::ThreadSubscriptionsCatchupTokens(tokens),
+                )
+                .await?;
+
+            trace!("Marking thread subscriptions as outdated.");
+            self.is_outdated.store(true, atomic::Ordering::SeqCst);
+            false
+        };
+        Ok(is_empty)
+    }
+}
+
+pub struct ThreadSubscriptionCatchup {
+    /// The task catching up thread subscriptions in the background.
+    _task: OnceCell<AbortOnDrop<()>>,
+
+    /// Whether the known list of thread subscriptions is outdated, i.e.
+    /// whether some thread subscriptions still need to be caught up from the
+    /// server.
+    is_outdated: Arc<AtomicBool>,
+
+    /// A weak reference to the parent [`Client`] instance.
+    client: WeakClient,
+
+    /// A sender to wake up the catchup task when new catchup tokens are
+    /// available.
+    ping_sender: Sender<()>,
+
+    /// A mutex to ensure there's only one writer on the thread subscriptions
+    /// catchup tokens at a time.
+    uniq_mutex: Arc<Mutex<()>>,
+}
+
+impl ThreadSubscriptionCatchup {
+    pub fn new(client: Client) -> Arc<Self> {
+        let is_outdated = Arc::new(AtomicBool::new(true));
+
+        let weak_client = WeakClient::from_client(&client);
+
+        let (ping_sender, ping_receiver) = channel(8);
+
+        let uniq_mutex = Arc::new(Mutex::new(()));
+
+        let this = Arc::new(Self {
+            _task: OnceCell::new(),
+            is_outdated,
+            client: weak_client,
+            ping_sender,
+            uniq_mutex,
+        });
+
+        // Create the task only if the client is configured to handle thread
+        // subscriptions.
+        if client.enabled_thread_subscriptions() {
+            let _ = this._task.get_or_init(|| {
+                AbortOnDrop::new(spawn(Self::thread_subscriptions_catchup_task(
+                    this.clone(),
+                    ping_receiver,
+                )))
+            });
+        }
+
+        this
+    }
+
+    /// Returns whether the known list of thread subscriptions is outdated,
+    /// i.e. whether more thread subscriptions need to be caught up.
+    pub(crate) fn is_outdated(&self) -> bool {
+        self.is_outdated.load(atomic::Ordering::SeqCst)
+    }
+
+    /// Store the new subscription changes, received either from the sync
+    /// response or from the MSC4308 companion endpoint.
+    #[instrument(skip_all)]
+    pub(crate) async fn sync_subscriptions(
+        &self,
+        subscribed: BTreeMap<OwnedRoomId, BTreeMap<OwnedEventId, ThreadSubscription>>,
+        unsubscribed: BTreeMap<OwnedRoomId, BTreeMap<OwnedEventId, ThreadUnsubscription>>,
+        token: Option<ThreadSubscriptionCatchupToken>,
+    ) -> Result<()> {
+        let Some(guard) = self.lock().await else {
+            // Client is shutting down.
+ return Ok(()); + }; + self.save_catchup_token(&guard, token).await?; + self.store_subscriptions(&guard, subscribed, unsubscribed).await?; + Ok(()) + } + + async fn store_subscriptions( + &self, + guard: &GuardedStoreAccess, + subscribed: BTreeMap>, + unsubscribed: BTreeMap>, + ) -> Result<()> { + if subscribed.is_empty() && unsubscribed.is_empty() { + // Nothing to do. + return Ok(()); + } + + trace!( + "saving {} new subscriptions and {} unsubscriptions", + subscribed.values().map(|by_room| by_room.len()).sum::(), + unsubscribed.values().map(|by_room| by_room.len()).sum::(), + ); + + // Take into account the new unsubscriptions. + for (room_id, room_map) in unsubscribed { + for (event_id, thread_sub) in room_map { + guard + .client + .state_store() + .upsert_thread_subscription( + &room_id, + &event_id, + StoredThreadSubscription { + status: ThreadSubscriptionStatus::Unsubscribed, + bump_stamp: Some(thread_sub.bump_stamp.into()), + }, + ) + .await?; + } + } + + // Take into account the new subscriptions. + for (room_id, room_map) in subscribed { + for (event_id, thread_sub) in room_map { + guard + .client + .state_store() + .upsert_thread_subscription( + &room_id, + &event_id, + StoredThreadSubscription { + status: ThreadSubscriptionStatus::Subscribed { + automatic: thread_sub.automatic, + }, + bump_stamp: Some(thread_sub.bump_stamp.into()), + }, + ) + .await?; + } + } + + Ok(()) + } + + /// Internal helper to lock writes to the thread subscriptions catchup + /// tokens list. + async fn lock(&self) -> Option { + let client = self.client.get()?; + let mutex_guard = self.uniq_mutex.clone().lock_owned().await; + Some(GuardedStoreAccess { + _mutex: mutex_guard, + client, + is_outdated: self.is_outdated.clone(), + }) + } + + /// Save a new catchup token (or absence thereof) in the state store. + async fn save_catchup_token( + &self, + guard: &GuardedStoreAccess, + token: Option, + ) -> Result<()> { + // Note: saving an empty tokens list will mark the thread subscriptions list as + // not outdated. + let mut tokens = guard.load_catchup_tokens().await?.unwrap_or_default(); + + if let Some(token) = token { + trace!(?token, "Saving catchup token"); + tokens.push(token); + } else { + trace!("No catchup token to save"); + } + + let is_token_list_empty = guard.save_catchup_tokens(tokens).await?; + + // Wake up the catchup task, in case it's waiting. + if !is_token_list_empty { + let _ = self.ping_sender.send(()).await; + } + + Ok(()) + } + + /// The background task listening to new catchup tokens, and using them to + /// catch up the thread subscriptions via the [MSC4308] companion + /// endpoint. + /// + /// It will continue to process catchup tokens until there are none, and + /// then wait for a new one to be available and inserted in the + /// database. + /// + /// It always processes catch up tokens from the newest to the oldest, since + /// newest tokens are more interesting than older ones. Indeed, they're + /// more likely to include entries with higher bump-stamps, i.e. to include + /// more recent thread subscriptions statuses for each thread, so more + /// relevant information. + /// + /// [MSC4308]: https://github.com/matrix-org/matrix-spec-proposals/pull/4308 + #[instrument(skip_all)] + async fn thread_subscriptions_catchup_task(this: Arc, mut ping_receiver: Receiver<()>) { + loop { + // Load the current catchup token. + let Some(guard) = this.lock().await else { + // Client is shutting down. 
+                return;
+            };
+
+            let store_tokens = match guard.load_catchup_tokens().await {
+                Ok(tokens) => tokens,
+                Err(err) => {
+                    warn!("Failed to load thread subscriptions catchup tokens: {err}");
+                    continue;
+                }
+            };
+
+            let Some(mut tokens) = store_tokens else {
+                // Release the mutex.
+                drop(guard);
+
+                // Wait for a wake up.
+                trace!("Waiting for an explicit wake up to process future thread subscriptions");
+
+                if let Some(()) = ping_receiver.recv().await {
+                    trace!("Woke up!");
+                    continue;
+                }
+
+                // Channel closed, the client is shutting down.
+                break;
+            };
+
+            // We do have some tokens. Pop the last value, and use it to catch up!
+            let last = tokens.pop().expect("must be set per `load_catchup_tokens` contract");
+
+            // Release the mutex before running the network request.
+            let client = guard.client.clone();
+            drop(guard);
+
+            // Start the actual catchup!
+            let req = assign!(ruma::api::client::threads::get_thread_subscriptions_changes::unstable::Request::new(), {
+                from: Some(last.from.clone()),
+                to: last.to.clone(),
+            });
+
+            match client.send(req).await {
+                Ok(resp) => {
+                    let guard = this
+                        .lock()
+                        .await
+                        .expect("a client instance is alive, so the locking should not fail");
+
+                    if let Err(err) =
+                        this.store_subscriptions(&guard, resp.subscribed, resp.unsubscribed).await
+                    {
+                        warn!("Failed to store caught up thread subscriptions: {err}");
+                        continue;
+                    }
+
+                    // Refresh the tokens, as the list might have changed while we sent the
+                    // request.
+                    let mut tokens = match guard.load_catchup_tokens().await {
+                        Ok(tokens) => tokens.unwrap_or_default(),
+                        Err(err) => {
+                            warn!("Failed to load thread subscriptions catchup tokens: {err}");
+                            continue;
+                        }
+                    };
+
+                    let Some(index) = tokens.iter().position(|t| *t == last) else {
+                        warn!("Thread subscriptions catchup token disappeared while processing it");
+                        continue;
+                    };
+
+                    if let Some(next_batch) = resp.end {
+                        // If the response contained a next batch token, reuse the same catchup
+                        // token entry, so the `to` value remains the same.
+                        tokens[index] =
+                            ThreadSubscriptionCatchupToken { from: next_batch, to: last.to };
+                    } else {
+                        // No next batch, we can remove this token from the list.
+                        tokens.remove(index);
+                    }
+
+                    if let Err(err) = guard.save_catchup_tokens(tokens).await {
+                        warn!("Failed to save updated thread subscriptions catchup tokens: {err}");
+                    }
+                }
+
+                Err(err) => {
+                    warn!("Failed to catch up thread subscriptions: {err}");
+                }
+            }
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use std::ops::Not as _;
+
+    use matrix_sdk_base::ThreadSubscriptionCatchupToken;
+    use matrix_sdk_test::async_test;
+
+    use crate::test_utils::client::MockClientBuilder;
+
+    #[async_test]
+    async fn test_load_save_catchup_tokens() {
+        let client = MockClientBuilder::new(None).build().await;
+
+        let tsc = client.thread_subscription_catchup();
+
+        // At first there are no catchup tokens, and we are outdated.
+        let guard = tsc.lock().await.unwrap();
+        assert!(guard.load_catchup_tokens().await.unwrap().is_none());
+        assert!(tsc.is_outdated());
+
+        // When I save a token,
+        let token =
+            ThreadSubscriptionCatchupToken { from: "from".to_owned(), to: Some("to".to_owned()) };
+        guard.save_catchup_tokens(vec![token.clone()]).await.unwrap();
+
+        // Well, it is saved,
+        let tokens = guard.load_catchup_tokens().await.unwrap();
+        assert_eq!(tokens, Some(vec![token]));
+
+        // And we are still outdated.
+ assert!(tsc.is_outdated()); + + // When I remove the token, + guard.save_catchup_tokens(vec![]).await.unwrap(); + + // It is gone, + assert!(guard.load_catchup_tokens().await.unwrap().is_none()); + + // And we are not outdated anymore! + assert!(tsc.is_outdated().not()); + } +} diff --git a/crates/matrix-sdk/src/room/mod.rs b/crates/matrix-sdk/src/room/mod.rs index 179bf08c799..b04edce7341 100644 --- a/crates/matrix-sdk/src/room/mod.rs +++ b/crates/matrix-sdk/src/room/mod.rs @@ -4336,8 +4336,25 @@ impl Room { &self, thread_root: &EventId, ) -> Result> { - // A bit of a lie at the moment, since thread subscriptions are not sync'd yet. - self.fetch_thread_subscription(thread_root.to_owned()).await + // If the thread subscriptions list is outdated, fetch from the server. + if self.client.thread_subscription_catchup().is_outdated() { + return self.fetch_thread_subscription(thread_root.to_owned()).await; + } + + // Otherwise, we can rely on the store information. + Ok(self + .client + .state_store() + .load_thread_subscription(self.room_id(), thread_root) + .await + .map(|maybe_sub| { + maybe_sub.and_then(|stored| match stored.status { + ThreadSubscriptionStatus::Unsubscribed => None, + ThreadSubscriptionStatus::Subscribed { automatic } => { + Some(ThreadSubscription { automatic }) + } + }) + })?) } } diff --git a/crates/matrix-sdk/src/sliding_sync/builder.rs b/crates/matrix-sdk/src/sliding_sync/builder.rs index 38d7e389e52..6af5c8859b7 100644 --- a/crates/matrix-sdk/src/sliding_sync/builder.rs +++ b/crates/matrix-sdk/src/sliding_sync/builder.rs @@ -183,6 +183,23 @@ impl SlidingSyncBuilder { self } + /// Set the Threads subscriptions extension configuration. + pub fn with_thread_subscriptions_extension( + mut self, + thread_subscriptions: http::request::ThreadSubscriptions, + ) -> Self { + self.extensions.get_or_insert_with(Default::default).thread_subscriptions = + thread_subscriptions; + self + } + + /// Unset the Threads subscriptions extension configuration. + pub fn without_thread_subscriptions_extension(mut self) -> Self { + self.extensions.get_or_insert_with(Default::default).thread_subscriptions = + Default::default(); + self + } + /// Sets a custom timeout duration for the sliding sync polling endpoint. 
/// /// This is the maximum time to wait before the sliding sync server returns diff --git a/crates/matrix-sdk/src/sliding_sync/client.rs b/crates/matrix-sdk/src/sliding_sync/client.rs index 96da1724adc..176b89804aa 100644 --- a/crates/matrix-sdk/src/sliding_sync/client.rs +++ b/crates/matrix-sdk/src/sliding_sync/client.rs @@ -1,8 +1,13 @@ use std::collections::BTreeSet; -use matrix_sdk_base::{sync::SyncResponse, RequestedRequiredStates}; +use matrix_sdk_base::{ + sync::SyncResponse, RequestedRequiredStates, ThreadSubscriptionCatchupToken, +}; use matrix_sdk_common::deserialized_responses::ProcessedToDeviceEvent; -use ruma::api::{client::sync::sync_events::v5 as http, FeatureFlag, SupportedVersions}; +use ruma::api::{ + client::sync::sync_events::v5::{self as http, response}, + FeatureFlag, SupportedVersions, +}; use tracing::error; use super::{SlidingSync, SlidingSyncBuilder}; @@ -160,10 +165,7 @@ impl SlidingSyncResponseProcessor { } #[cfg(feature = "e2e-encryption")] - pub async fn handle_encryption( - &mut self, - extensions: &http::response::Extensions, - ) -> Result<()> { + pub async fn handle_encryption(&mut self, extensions: &response::Extensions) -> Result<()> { // This is an internal API misuse if this is triggered (calling // `handle_room_response` before this function), so panic is fine. assert!(self.response.is_none()); @@ -204,6 +206,25 @@ impl SlidingSyncResponseProcessor { Ok(()) } + pub async fn handle_thread_subscriptions( + &mut self, + previous_pos: Option<&str>, + thread_subs: response::ThreadSubscriptions, + ) -> Result<()> { + let catchup_token = + thread_subs.prev_batch.map(|prev_batch| ThreadSubscriptionCatchupToken { + from: prev_batch, + to: previous_pos.map(|s| s.to_owned()), + }); + + self.client + .thread_subscription_catchup() + .sync_subscriptions(thread_subs.subscribed, thread_subs.unsubscribed, catchup_token) + .await?; + + Ok(()) + } + pub async fn process_and_take_response(mut self) -> Result { let mut response = self.response.take().unwrap_or_default(); diff --git a/crates/matrix-sdk/src/sliding_sync/mod.rs b/crates/matrix-sdk/src/sliding_sync/mod.rs index 5cd0511a984..eae13a49efb 100644 --- a/crates/matrix-sdk/src/sliding_sync/mod.rs +++ b/crates/matrix-sdk/src/sliding_sync/mod.rs @@ -240,7 +240,7 @@ impl SlidingSync { #[instrument(skip_all)] async fn handle_response( &self, - sliding_sync_response: http::Response, + mut sliding_sync_response: http::Response, position: &mut SlidingSyncPositionMarkers, requested_required_states: RequestedRequiredStates, ) -> Result { @@ -272,6 +272,22 @@ impl SlidingSync { let mut response_processor = SlidingSyncResponseProcessor::new(self.inner.client.clone()); + // Process thread subscriptions if they're available. + // + // It's important to do this *before* handling the room responses, so that + // notifications can be properly generated based on the thread subscriptions, + // for the events in threads we've subscribed to. + if self.is_thread_subscriptions_enabled() { + response_processor + .handle_thread_subscriptions( + position.pos.as_deref(), + std::mem::take( + &mut sliding_sync_response.extensions.thread_subscriptions, + ), + ) + .await?; + } + #[cfg(feature = "e2e-encryption")] if self.is_e2ee_enabled() { response_processor.handle_encryption(&sliding_sync_response.extensions).await? @@ -623,6 +639,13 @@ impl SlidingSync { self.inner.sticky.read().unwrap().data().extensions.e2ee.enabled == Some(true) } + /// Is the thread subscriptions extension enabled for this sliding sync + /// instance? 
+    fn is_thread_subscriptions_enabled(&self) -> bool {
+        self.inner.sticky.read().unwrap().data().extensions.thread_subscriptions.enabled
+            == Some(true)
+    }
+
     #[cfg(not(feature = "e2e-encryption"))]
     fn is_e2ee_enabled(&self) -> bool {
         false
@@ -636,15 +659,19 @@ impl SlidingSync {
         || !self.inner.lists.read().await.is_empty()
     }
 
+    /// Sends a single sliding sync request and returns the response summary.
+    ///
+    /// Public for testing purposes only.
+    #[doc(hidden)]
     #[instrument(skip_all, fields(pos, conn_id = self.inner.id))]
-    async fn sync_once(&self) -> Result {
+    pub async fn sync_once(&self) -> Result {
         let (request, request_config, position_guard) =
             self.generate_sync_request(&mut LazyTransactionId::new()).await?;
 
-        // Send the request, kaboom.
+        // Send the request.
         let summaries = self.send_sync_request(request, request_config, position_guard).await?;
 
-        // Notify a new sync was received
+        // Notify a new sync was received.
         self.inner.client.inner.sync_beat.notify(usize::MAX);
 
         Ok(summaries)
diff --git a/crates/matrix-sdk/src/test_utils/mocks/mod.rs b/crates/matrix-sdk/src/test_utils/mocks/mod.rs
index b72a28dcf3d..36ed4a2e613 100644
--- a/crates/matrix-sdk/src/test_utils/mocks/mod.rs
+++ b/crates/matrix-sdk/src/test_utils/mocks/mod.rs
@@ -33,6 +33,7 @@ use ruma::{
     api::client::{
         receipt::create_receipt::v3::ReceiptType,
         room::Visibility,
+        sync::sync_events::v5,
         threads::get_thread_subscriptions_changes::unstable::{
             ThreadSubscription, ThreadUnsubscription,
         },
@@ -68,7 +69,7 @@ pub mod encryption;
 pub mod oauth;
 
 use super::client::MockClientBuilder;
-use crate::{room::IncludeRelations, Client, OwnedServerName, Room};
+use crate::{room::IncludeRelations, Client, OwnedServerName, Room, SlidingSyncBuilder};
 
 /// Structure used to store the crypto keys uploaded to the server.
 /// They will be served back to clients when requested.
@@ -353,6 +354,13 @@ impl MatrixMockServer {
         )
     }
 
+    /// Mocks the sliding sync endpoint.
+    pub fn mock_sliding_sync(&self) -> MockEndpoint<'_, SlidingSyncEndpoint> {
+        let mock = Mock::given(method("POST"))
+            .and(path("/_matrix/client/unstable/org.matrix.simplified_msc3575/sync"));
+        self.mock_endpoint(mock, SlidingSyncEndpoint)
+    }
+
     /// Creates a prebuilt mock for joining a room.
     ///
     /// # Examples
@@ -4172,6 +4180,8 @@ pub struct GetThreadSubscriptionsEndpoint {
     subscribed: BTreeMap>,
     /// New thread unsubscriptions per (room id, thread root event id).
     unsubscribed: BTreeMap>,
+    /// Optional delay before responding to the query.
+    delay: Option<Duration>,
 }
 
 impl<'a> MockEndpoint<'a, GetThreadSubscriptionsEndpoint> {
@@ -4197,6 +4207,12 @@
         self
     }
 
+    /// Respond to the query with the given delay.
+    pub fn with_delay(mut self, delay: Duration) -> Self {
+        self.endpoint.delay = Some(delay);
+        self
+    }
+
     /// Match the `from` query parameter to a given value.
     pub fn match_from(self, from: &str) -> Self {
         Self { mock: self.mock.and(query_param("from", from)), ..self }
@@ -4214,7 +4230,14 @@
             "unsubscribed": self.endpoint.unsubscribed,
             "end": end,
         });
-        self.respond_with(ResponseTemplate::new(200).set_body_json(response_body))
+
+        let mut template = ResponseTemplate::new(200).set_body_json(response_body);
+
+        if let Some(delay) = self.endpoint.delay {
+            template = template.set_delay(delay);
+        }
+
+        self.respond_with(template)
     }
 }
 
@@ -4280,3 +4303,39 @@ impl<'a> MockEndpoint<'a, GetHierarchyEndpoint> {
         })))
     }
 }
+
+/// A prebuilt mock for running simplified sliding sync.
+pub struct SlidingSyncEndpoint; + +impl<'a> MockEndpoint<'a, SlidingSyncEndpoint> { + /// Mocks the sliding sync endpoint with the given response. + pub fn ok(self, response: v5::Response) -> MatrixMock<'a> { + // A bit silly that we need to destructure all the fields ourselves, but + // Response isn't serializable :'( + self.respond_with(ResponseTemplate::new(200).set_body_json(json!({ + "txn_id": response.txn_id, + "pos": response.pos, + "lists": response.lists, + "rooms": response.rooms, + "extensions": response.extensions, + }))) + } + + /// Temporarily mocks the sync with the given endpoint and runs a client + /// sync with it. + /// + /// After calling this function, the sync endpoint isn't mocked anymore. + pub async fn ok_and_run SlidingSyncBuilder>( + self, + client: &Client, + on_builder: F, + response: v5::Response, + ) { + let _scope = self.ok(response).mount_as_scoped().await; + + let sliding_sync = + on_builder(client.sliding_sync("test_id").unwrap()).build().await.unwrap(); + + let _summary = sliding_sync.sync_once().await.unwrap(); + } +} diff --git a/crates/matrix-sdk/tests/integration/client.rs b/crates/matrix-sdk/tests/integration/client.rs index 501dd086d0d..44cb0954c16 100644 --- a/crates/matrix-sdk/tests/integration/client.rs +++ b/crates/matrix-sdk/tests/integration/client.rs @@ -6,12 +6,13 @@ use futures_util::FutureExt; use matrix_sdk::{ authentication::oauth::{error::OAuthTokenRevocationError, OAuthError}, config::{RequestConfig, StoreConfig, SyncSettings, SyncToken}, - store::RoomLoadSettings, + sleep::sleep, + store::{RoomLoadSettings, ThreadSubscriptionStatus}, sync::{RoomUpdate, State}, test_utils::{ client::mock_matrix_session, mocks::MatrixMockServer, no_retry_test_client_with_server, }, - Client, Error, MemoryStore, StateChanges, StateStore, + Client, Error, MemoryStore, SlidingSyncList, StateChanges, StateStore, ThreadingSupport, }; use matrix_sdk_base::{sync::RoomUpdates, RoomState}; use matrix_sdk_common::executor::spawn; @@ -36,6 +37,10 @@ use ruma::{ get_public_rooms, get_public_rooms_filtered::{self, v3::Request as PublicRoomsFilterRequest}, }, + sync::sync_events::v5, + threads::get_thread_subscriptions_changes::unstable::{ + ThreadSubscription, ThreadUnsubscription, + }, uiaa, }, assign, device_id, @@ -53,7 +58,7 @@ use ruma::{ room::JoinRule, room_id, serde::Raw, - uint, user_id, OwnedUserId, + uint, user_id, EventId, OwnedUserId, RoomId, }; use serde_json::{json, Value as JsonValue}; use stream_assert::{assert_next_matches, assert_pending}; @@ -1511,8 +1516,6 @@ async fn test_room_sync_state_after() { #[async_test] async fn test_server_vendor_info() { - use matrix_sdk::test_utils::mocks::MatrixMockServer; - let server = MatrixMockServer::new().await; let client = server.client_builder().build().await; @@ -1527,8 +1530,6 @@ async fn test_server_vendor_info() { #[async_test] async fn test_server_vendor_info_with_missing_fields() { - use matrix_sdk::test_utils::mocks::MatrixMockServer; - let server = MatrixMockServer::new().await; let client = server.client_builder().build().await; @@ -1544,10 +1545,6 @@ async fn test_server_vendor_info_with_missing_fields() { #[async_test] async fn test_fetch_thread_subscriptions() { - use ruma::api::client::threads::get_thread_subscriptions_changes::unstable::{ - ThreadSubscription, ThreadUnsubscription, - }; - let server = MatrixMockServer::new().await; let client = server.client_builder().build().await; @@ -1592,3 +1589,211 @@ async fn test_fetch_thread_subscriptions() { let u = 
&response.unsubscribed[&room3][&thread3]; assert_eq!(u.bump_stamp, uint!(13)); } + +/// Create a sliding sync thread_subscription response with no `prev_batch` +/// token. +fn thread_subscription_response( + room1: &RoomId, + thread1: &EventId, + room2: &RoomId, + thread2: &EventId, +) -> v5::response::ThreadSubscriptions { + assign!(v5::response::ThreadSubscriptions::default(), { + subscribed: { + let mut map = BTreeMap::new(); + map.insert(room1.to_owned(), { + let mut threads = BTreeMap::new(); + threads.insert(thread1.to_owned(), ThreadSubscription::new(true, uint!(42))); + threads + }); + map + }, + unsubscribed: { + let mut map = BTreeMap::new(); + map.insert(room2.to_owned(), { + let mut threads = BTreeMap::new(); + threads.insert(thread2.to_owned(), ThreadUnsubscription::new(uint!(7))); + threads + }); + map + }, + prev_batch: None, + }) +} + +#[async_test] +async fn test_sync_thread_subscriptions() { + let server = MatrixMockServer::new().await; + let client = server.client_builder().build().await; + + let room1 = owned_room_id!("!room1:example.com"); + let room2 = owned_room_id!("!room2:example.com"); + + let thread1 = owned_event_id!("$thread1:example.com"); + let thread2 = owned_event_id!("$thread2:example.com"); + + // At first, there are no thread subscriptions at all. + let stored1 = client + .state_store() + .load_thread_subscription(&room1, &thread1) + .await + .expect("loading room1/thread1 works fine"); + assert_matches!(stored1, None); + + let stored2 = client + .state_store() + .load_thread_subscription(&room2, &thread2) + .await + .expect("loading room2/thread2 works fine"); + assert_matches!(stored2, None); + + // When I sliding-sync thread subscriptions, + server + .mock_sliding_sync() + .ok_and_run( + &client, + |config_builder| { + config_builder.with_thread_subscriptions_extension( + assign!(v5::request::ThreadSubscriptions::default(), { + enabled: Some(true), + limit: Some(uint!(10)), + }), + ) + }, + assign!(v5::Response::new("pos".to_owned()), { + extensions: assign!(v5::response::Extensions::default(), { + thread_subscriptions: thread_subscription_response( + &room1, &thread1, &room2, &thread2, + ), + }), + }), + ) + .await; + + // Then they're stored in the local database. 
+ let stored1 = client + .state_store() + .load_thread_subscription(&room1, &thread1) + .await + .expect("loading room1/thread1 works fine") + .expect("found room1/thread1 subscription"); + + assert_eq!(stored1.status, ThreadSubscriptionStatus::Subscribed { automatic: true }); + assert_eq!(stored1.bump_stamp, Some(42)); + + let stored2 = client + .state_store() + .load_thread_subscription(&room2, &thread2) + .await + .expect("loading room2/thread2 works fine") + .expect("found room2/thread2 unsubscription"); + + assert_eq!(stored2.status, ThreadSubscriptionStatus::Unsubscribed); + assert_eq!(stored2.bump_stamp, Some(7)); +} + +#[async_test] +async fn test_sync_thread_subscriptions_with_catchup() { + let server = MatrixMockServer::new().await; + let client = server + .client_builder() + .on_builder(|builder| { + builder.with_threading_support(ThreadingSupport::Enabled { with_subscriptions: true }) + }) + .build() + .await; + + let room_id1 = owned_room_id!("!room1:example.com"); + let room_id2 = owned_room_id!("!room2:example.com"); + + let thread1 = owned_event_id!("$thread1:example.com"); + let thread2 = owned_event_id!("$thread2:example.com"); + let thread3 = owned_event_id!("$thread3:example.com"); + + // The provided catchup token will be used to fetch more thread + // subscriptions via the msc4308 companion endpoint. + server + .mock_get_thread_subscriptions() + .match_from("catchup_token") + .add_subscription( + room_id1.clone(), + thread3.clone(), + ThreadSubscription::new(false, uint!(1337)), + ) + .with_delay(Duration::from_millis(300)) // Simulate some network delay. + // No more subscriptions after the first catchup request. + .ok(None) + .mock_once() + .mount() + .await; + + // When I sliding-sync thread subscriptions, and the response includes this + // catch-up token, + let mut thread_subscriptions = + thread_subscription_response(&room_id1, &thread1, &room_id2, &thread2); + thread_subscriptions.prev_batch = Some("catchup_token".to_owned()); + + server + .mock_sliding_sync() + .ok_and_run( + &client, + |config_builder| { + config_builder + .with_thread_subscriptions_extension( + assign!(v5::request::ThreadSubscriptions::default(), { + enabled: Some(true), + limit: Some(uint!(10)), + }), + ) + .add_list(SlidingSyncList::builder("rooms")) + }, + assign!(v5::Response::new("pos".to_owned()), { + rooms: { + let mut rooms = BTreeMap::new(); + rooms.insert(room_id1.clone(), v5::response::Room::default()); + rooms.insert(room_id2.clone(), v5::response::Room::default()); + rooms + }, + extensions: assign!(v5::response::Extensions::default(), { + thread_subscriptions, + }), + }), + ) + .await; + + // If I try to get the subscription status for thread 1, it's still hitting + // network, because it doesn't know yet about the result of the catch-up + // request. (Ideally, the choice of whether some information is outdated or + // not would be per room/thread pair, but for simplicity it's global right + // now.) + server + .mock_room_get_thread_subscription() + .match_room_id(room_id1.clone()) + .match_thread_id(thread1.clone()) + .ok(true) + .mock_once() + .mount() + .await; + + let room1 = client.get_room(&room_id1).unwrap(); + let sub1 = room1.load_or_fetch_thread_subscription(&thread1).await.unwrap(); + assert_eq!(sub1, Some(matrix_sdk::room::ThreadSubscription { automatic: true })); + + // All the thread subscriptions are eventually known in the database. 
+ sleep(Duration::from_millis(400)).await; + + let stored3 = client + .state_store() + .load_thread_subscription(&room_id1, &thread3) + .await + .expect("loading room1/thread3 works fine") + .expect("found room1/thread3 subscription"); + assert_eq!(stored3.status, ThreadSubscriptionStatus::Subscribed { automatic: false }); + assert_eq!(stored3.bump_stamp, Some(1337)); + + // So the client will use the database only to load_or_fetch thread + // subscriptions. (Which is confirmed by the absence of mocking the + // room_get_thread_subscription endpoint for thread3.) + let sub3 = room1.load_or_fetch_thread_subscription(&thread3).await.unwrap(); + assert_eq!(sub3, Some(matrix_sdk::room::ThreadSubscription { automatic: false })); +} diff --git a/labs/multiverse/src/widgets/room_view/mod.rs b/labs/multiverse/src/widgets/room_view/mod.rs index 3ef3c085784..fd24e6f9d3d 100644 --- a/labs/multiverse/src/widgets/room_view/mod.rs +++ b/labs/multiverse/src/widgets/room_view/mod.rs @@ -522,7 +522,7 @@ impl RoomView { async fn print_thread_subscription_status(&mut self) { if let TimelineKind::Thread { thread_root, .. } = &self.kind { self.call_with_room(async |room, status_handle| { - match room.fetch_thread_subscription(thread_root.clone()).await { + match room.load_or_fetch_thread_subscription(thread_root).await { Ok(Some(subscription)) => { status_handle.set_message(format!( "Thread subscription status: {}",