diff --git a/changelog.d/18762.feature b/changelog.d/18762.feature new file mode 100644 index 00000000000..aa8e91de01d --- /dev/null +++ b/changelog.d/18762.feature @@ -0,0 +1 @@ +Implement the push rules for experimental [MSC4306: Thread Subscriptions](https://github.com/matrix-org/matrix-doc/issues/4306). \ No newline at end of file diff --git a/rust/benches/evaluator.rs b/rust/benches/evaluator.rs index 28537e187ec..96169fd45d9 100644 --- a/rust/benches/evaluator.rs +++ b/rust/benches/evaluator.rs @@ -61,6 +61,7 @@ fn bench_match_exact(b: &mut Bencher) { vec![], false, false, + false, ) .unwrap(); @@ -71,10 +72,10 @@ fn bench_match_exact(b: &mut Bencher) { }, )); - let matched = eval.match_condition(&condition, None, None).unwrap(); + let matched = eval.match_condition(&condition, None, None, None).unwrap(); assert!(matched, "Didn't match"); - b.iter(|| eval.match_condition(&condition, None, None).unwrap()); + b.iter(|| eval.match_condition(&condition, None, None, None).unwrap()); } #[bench] @@ -107,6 +108,7 @@ fn bench_match_word(b: &mut Bencher) { vec![], false, false, + false, ) .unwrap(); @@ -117,10 +119,10 @@ fn bench_match_word(b: &mut Bencher) { }, )); - let matched = eval.match_condition(&condition, None, None).unwrap(); + let matched = eval.match_condition(&condition, None, None, None).unwrap(); assert!(matched, "Didn't match"); - b.iter(|| eval.match_condition(&condition, None, None).unwrap()); + b.iter(|| eval.match_condition(&condition, None, None, None).unwrap()); } #[bench] @@ -153,6 +155,7 @@ fn bench_match_word_miss(b: &mut Bencher) { vec![], false, false, + false, ) .unwrap(); @@ -163,10 +166,10 @@ fn bench_match_word_miss(b: &mut Bencher) { }, )); - let matched = eval.match_condition(&condition, None, None).unwrap(); + let matched = eval.match_condition(&condition, None, None, None).unwrap(); assert!(!matched, "Didn't match"); - b.iter(|| eval.match_condition(&condition, None, None).unwrap()); + b.iter(|| eval.match_condition(&condition, None, None, None).unwrap()); } #[bench] @@ -199,6 +202,7 @@ fn bench_eval_message(b: &mut Bencher) { vec![], false, false, + false, ) .unwrap(); @@ -210,7 +214,8 @@ fn bench_eval_message(b: &mut Bencher) { false, false, false, + false, ); - b.iter(|| eval.run(&rules, Some("bob"), Some("person"))); + b.iter(|| eval.run(&rules, Some("bob"), Some("person"), None)); } diff --git a/rust/src/push/base_rules.rs b/rust/src/push/base_rules.rs index e0832ada1c7..ec027ca251e 100644 --- a/rust/src/push/base_rules.rs +++ b/rust/src/push/base_rules.rs @@ -290,6 +290,26 @@ pub const BASE_APPEND_CONTENT_RULES: &[PushRule] = &[PushRule { }]; pub const BASE_APPEND_UNDERRIDE_RULES: &[PushRule] = &[ + PushRule { + rule_id: Cow::Borrowed("global/content/.io.element.msc4306.rule.unsubscribed_thread"), + priority_class: 1, + conditions: Cow::Borrowed(&[Condition::Known( + KnownCondition::Msc4306ThreadSubscription { subscribed: false }, + )]), + actions: Cow::Borrowed(&[]), + default: true, + default_enabled: true, + }, + PushRule { + rule_id: Cow::Borrowed("global/content/.io.element.msc4306.rule.subscribed_thread"), + priority_class: 1, + conditions: Cow::Borrowed(&[Condition::Known( + KnownCondition::Msc4306ThreadSubscription { subscribed: true }, + )]), + actions: Cow::Borrowed(&[Action::Notify, SOUND_ACTION]), + default: true, + default_enabled: true, + }, PushRule { rule_id: Cow::Borrowed("global/underride/.m.rule.call"), priority_class: 1, diff --git a/rust/src/push/evaluator.rs b/rust/src/push/evaluator.rs index db406acb881..1cbca4c6355 100644 --- a/rust/src/push/evaluator.rs +++ b/rust/src/push/evaluator.rs @@ -106,8 +106,11 @@ pub struct PushRuleEvaluator { /// flag as MSC1767 (extensible events core). msc3931_enabled: bool, - // If MSC4210 (remove legacy mentions) is enabled. + /// If MSC4210 (remove legacy mentions) is enabled. msc4210_enabled: bool, + + /// If MSC4306 (thread subscriptions) is enabled. + msc4306_enabled: bool, } #[pymethods] @@ -126,6 +129,7 @@ impl PushRuleEvaluator { room_version_feature_flags, msc3931_enabled, msc4210_enabled, + msc4306_enabled, ))] pub fn py_new( flattened_keys: BTreeMap, @@ -138,6 +142,7 @@ impl PushRuleEvaluator { room_version_feature_flags: Vec, msc3931_enabled: bool, msc4210_enabled: bool, + msc4306_enabled: bool, ) -> Result { let body = match flattened_keys.get("content.body") { Some(JsonValue::Value(SimpleJsonValue::Str(s))) => s.clone().into_owned(), @@ -156,6 +161,7 @@ impl PushRuleEvaluator { room_version_feature_flags, msc3931_enabled, msc4210_enabled, + msc4306_enabled, }) } @@ -167,12 +173,19 @@ impl PushRuleEvaluator { /// /// Returns the set of actions, if any, that match (filtering out any /// `dont_notify` and `coalesce` actions). - #[pyo3(signature = (push_rules, user_id=None, display_name=None))] + /// + /// msc4306_thread_subscription_state: (Only populated if MSC4306 is enabled) + /// The thread subscription state corresponding to the thread containing this event. + /// - `None` if the event is not in a thread, or if MSC4306 is disabled. + /// - `Some(true)` if the event is in a thread and the user has a subscription for that thread + /// - `Some(false)` if the event is in a thread and the user does NOT have a subscription for that thread + #[pyo3(signature = (push_rules, user_id=None, display_name=None, msc4306_thread_subscription_state=None))] pub fn run( &self, push_rules: &FilteredPushRules, user_id: Option<&str>, display_name: Option<&str>, + msc4306_thread_subscription_state: Option, ) -> Vec { 'outer: for (push_rule, enabled) in push_rules.iter() { if !enabled { @@ -204,7 +217,12 @@ impl PushRuleEvaluator { Condition::Known(KnownCondition::RoomVersionSupports { feature: _ }), ); - match self.match_condition(condition, user_id, display_name) { + match self.match_condition( + condition, + user_id, + display_name, + msc4306_thread_subscription_state, + ) { Ok(true) => {} Ok(false) => continue 'outer, Err(err) => { @@ -237,14 +255,20 @@ impl PushRuleEvaluator { } /// Check if the given condition matches. - #[pyo3(signature = (condition, user_id=None, display_name=None))] + #[pyo3(signature = (condition, user_id=None, display_name=None, msc4306_thread_subscription_state=None))] fn matches( &self, condition: Condition, user_id: Option<&str>, display_name: Option<&str>, + msc4306_thread_subscription_state: Option, ) -> bool { - match self.match_condition(&condition, user_id, display_name) { + match self.match_condition( + &condition, + user_id, + display_name, + msc4306_thread_subscription_state, + ) { Ok(true) => true, Ok(false) => false, Err(err) => { @@ -262,6 +286,7 @@ impl PushRuleEvaluator { condition: &Condition, user_id: Option<&str>, display_name: Option<&str>, + msc4306_thread_subscription_state: Option, ) -> Result { let known_condition = match condition { Condition::Known(known) => known, @@ -393,6 +418,13 @@ impl PushRuleEvaluator { && self.room_version_feature_flags.contains(&flag) } } + KnownCondition::Msc4306ThreadSubscription { subscribed } => { + if !self.msc4306_enabled { + false + } else { + msc4306_thread_subscription_state == Some(*subscribed) + } + } }; Ok(result) @@ -536,10 +568,11 @@ fn push_rule_evaluator() { vec![], true, false, + false, ) .unwrap(); - let result = evaluator.run(&FilteredPushRules::default(), None, Some("bob")); + let result = evaluator.run(&FilteredPushRules::default(), None, Some("bob"), None); assert_eq!(result.len(), 3); } @@ -566,6 +599,7 @@ fn test_requires_room_version_supports_condition() { flags, true, false, + false, ) .unwrap(); @@ -575,6 +609,7 @@ fn test_requires_room_version_supports_condition() { &FilteredPushRules::default(), Some("@bob:example.org"), None, + None, ); assert_eq!(result.len(), 3); @@ -593,7 +628,17 @@ fn test_requires_room_version_supports_condition() { }; let rules = PushRules::new(vec![custom_rule]); result = evaluator.run( - &FilteredPushRules::py_new(rules, BTreeMap::new(), true, false, true, false, false), + &FilteredPushRules::py_new( + rules, + BTreeMap::new(), + true, + false, + true, + false, + false, + false, + ), + None, None, None, ); diff --git a/rust/src/push/mod.rs b/rust/src/push/mod.rs index bd0e853ac31..b07a12e5ccd 100644 --- a/rust/src/push/mod.rs +++ b/rust/src/push/mod.rs @@ -369,6 +369,10 @@ pub enum KnownCondition { RoomVersionSupports { feature: Cow<'static, str>, }, + #[serde(rename = "io.element.msc4306.thread_subscription")] + Msc4306ThreadSubscription { + subscribed: bool, + }, } impl<'source> IntoPyObject<'source> for Condition { @@ -547,11 +551,13 @@ pub struct FilteredPushRules { msc3664_enabled: bool, msc4028_push_encrypted_events: bool, msc4210_enabled: bool, + msc4306_enabled: bool, } #[pymethods] impl FilteredPushRules { #[new] + #[allow(clippy::too_many_arguments)] pub fn py_new( push_rules: PushRules, enabled_map: BTreeMap, @@ -560,6 +566,7 @@ impl FilteredPushRules { msc3664_enabled: bool, msc4028_push_encrypted_events: bool, msc4210_enabled: bool, + msc4306_enabled: bool, ) -> Self { Self { push_rules, @@ -569,6 +576,7 @@ impl FilteredPushRules { msc3664_enabled, msc4028_push_encrypted_events, msc4210_enabled, + msc4306_enabled, } } @@ -619,6 +627,10 @@ impl FilteredPushRules { return false; } + if !self.msc4306_enabled && rule.rule_id.contains("/.io.element.msc4306.rule.") { + return false; + } + true }) .map(|r| { diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index da4fa29da70..bb9d5dbcaab 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -25,6 +25,7 @@ Any, Collection, Dict, + FrozenSet, List, Mapping, Optional, @@ -477,8 +478,18 @@ async def _action_for_event_by_user( event.room_version.msc3931_push_features, self.hs.config.experimental.msc1767_enabled, # MSC3931 flag self.hs.config.experimental.msc4210_enabled, + self.hs.config.experimental.msc4306_enabled, ) + msc4306_thread_subscribers: Optional[FrozenSet[str]] = None + if self.hs.config.experimental.msc4306_enabled and thread_id != MAIN_TIMELINE: + # pull out, in batch, all local subscribers to this thread + # (in the common case, they will all be getting processed for push + # rules right now) + msc4306_thread_subscribers = await self.store.get_subscribers_to_thread( + event.room_id, thread_id + ) + for uid, rules in rules_by_user.items(): if event.sender == uid: continue @@ -503,7 +514,13 @@ async def _action_for_event_by_user( # current user, it'll be added to the dict later. actions_by_user[uid] = [] - actions = evaluator.run(rules, uid, display_name) + msc4306_thread_subscription_state: Optional[bool] = None + if msc4306_thread_subscribers is not None: + msc4306_thread_subscription_state = uid in msc4306_thread_subscribers + + actions = evaluator.run( + rules, uid, display_name, msc4306_thread_subscription_state + ) if "notify" in actions: # Push rules say we should notify the user of this event actions_by_user[uid] = actions diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index 22948f8c220..d6861405569 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -110,6 +110,7 @@ def _load_rules( msc3381_polls_enabled=experimental_config.msc3381_polls_enabled, msc4028_push_encrypted_events=experimental_config.msc4028_push_encrypted_events, msc4210_enabled=experimental_config.msc4210_enabled, + msc4306_enabled=experimental_config.msc4306_enabled, ) return filtered_rules diff --git a/synapse/storage/databases/main/thread_subscriptions.py b/synapse/storage/databases/main/thread_subscriptions.py index a99ef430717..24a99cf4490 100644 --- a/synapse/storage/databases/main/thread_subscriptions.py +++ b/synapse/storage/databases/main/thread_subscriptions.py @@ -14,6 +14,7 @@ from typing import ( TYPE_CHECKING, Any, + FrozenSet, Iterable, List, Optional, @@ -99,6 +100,7 @@ def process_replication_rows( self.get_subscription_for_thread.invalidate( (row.user_id, row.room_id, row.event_id) ) + self.get_subscribers_to_thread.invalidate((row.room_id, row.event_id)) super().process_replication_rows(stream_name, instance_name, token, rows) @@ -194,6 +196,16 @@ async def subscribe_user_to_thread( """ assert self._can_write_to_thread_subscriptions + def _invalidate_subscription_caches(txn: LoggingTransaction) -> None: + txn.call_after( + self.get_subscription_for_thread.invalidate, + (user_id, room_id, thread_root_event_id), + ) + txn.call_after( + self.get_subscribers_to_thread.invalidate, + (room_id, thread_root_event_id), + ) + def _subscribe_user_to_thread_txn( txn: LoggingTransaction, ) -> Optional[Union[int, AutomaticSubscriptionConflicted]]: @@ -234,10 +246,7 @@ def _subscribe_user_to_thread_txn( "unsubscribed_at_topological_ordering": None, }, ) - txn.call_after( - self.get_subscription_for_thread.invalidate, - (user_id, room_id, thread_root_event_id), - ) + _invalidate_subscription_caches(txn) return stream_id # we already have either a subscription or a prior unsubscription here @@ -291,10 +300,7 @@ def _subscribe_user_to_thread_txn( "unsubscribed_at_topological_ordering": None, }, ) - txn.call_after( - self.get_subscription_for_thread.invalidate, - (user_id, room_id, thread_root_event_id), - ) + _invalidate_subscription_caches(txn) return stream_id @@ -376,6 +382,10 @@ def _unsubscribe_user_from_thread_txn(txn: LoggingTransaction) -> Optional[int]: self.get_subscription_for_thread.invalidate, (user_id, room_id, thread_root_event_id), ) + txn.call_after( + self.get_subscribers_to_thread.invalidate, + (room_id, thread_root_event_id), + ) return stream_id @@ -388,7 +398,9 @@ async def purge_thread_subscription_settings_for_user(self, user_id: str) -> Non Purge all subscriptions for the user. The fact that subscriptions have been purged will not be streamed; all stream rows for the user will in fact be removed. - This is intended only for dealing with user deactivation. + + This must only be used for user deactivation, + because it does not invalidate the `subscribers_to_thread` cache. """ def _purge_thread_subscription_settings_for_user_txn( @@ -449,6 +461,42 @@ async def get_subscription_for_thread( return ThreadSubscription(automatic=automatic) + # max_entries=100 rationale: + # this returns a potentially large datastructure + # (since each entry contains a set which contains a potentially large number of user IDs), + # whereas the default of 10'000 entries for @cached feels more + # suitable for very small cache entries. + # + # Overall, when bearing in mind the usual profile of a small community-server or company-server + # (where cache tuning hasn't been done, so we're in out-of-box configuration), it is very + # unlikely we would benefit from keeping hot the subscribers for as many as 100 threads, + # since it's unlikely that so many threads will be active in a short span of time on a small homeserver. + # It feels that medium servers will probably also not exhaust this limit. + # Larger homeservers are more likely to be carefully tuned, either with a larger global cache factor + # or carefully following the usage patterns & cache metrics. + # Finally, the query is not so intensive that computing it every time is a huge deal, but given people + # often send messages back-to-back in the same thread it seems like it would offer a mild benefit. + @cached(max_entries=100) + async def get_subscribers_to_thread( + self, room_id: str, thread_root_event_id: str + ) -> FrozenSet[str]: + """ + Returns: + the set of user_ids for local users who are subscribed to the given thread. + """ + return frozenset( + await self.db_pool.simple_select_onecol( + table="thread_subscriptions", + keyvalues={ + "room_id": room_id, + "event_id": thread_root_event_id, + "subscribed": True, + }, + retcol="user_id", + desc="get_subscribers_to_thread", + ) + ) + def get_max_thread_subscriptions_stream_id(self) -> int: """Get the current maximum stream_id for thread subscriptions. diff --git a/synapse/synapse_rust/push.pyi b/synapse/synapse_rust/push.pyi index 3f317c32889..a3e12ad648e 100644 --- a/synapse/synapse_rust/push.pyi +++ b/synapse/synapse_rust/push.pyi @@ -49,6 +49,7 @@ class FilteredPushRules: msc3664_enabled: bool, msc4028_push_encrypted_events: bool, msc4210_enabled: bool, + msc4306_enabled: bool, ): ... def rules(self) -> Collection[Tuple[PushRule, bool]]: ... @@ -67,13 +68,19 @@ class PushRuleEvaluator: room_version_feature_flags: Tuple[str, ...], msc3931_enabled: bool, msc4210_enabled: bool, + msc4306_enabled: bool, ): ... def run( self, push_rules: FilteredPushRules, user_id: Optional[str], display_name: Optional[str], + msc4306_thread_subscription_state: Optional[bool], ) -> Collection[Union[Mapping, str]]: ... def matches( - self, condition: JsonDict, user_id: Optional[str], display_name: Optional[str] + self, + condition: JsonDict, + user_id: Optional[str], + display_name: Optional[str], + msc4306_thread_subscription_state: Optional[bool] = None, ) -> bool: ... diff --git a/tests/push/test_bulk_push_rule_evaluator.py b/tests/push/test_bulk_push_rule_evaluator.py index 6c8c3a09de2..fad5c7affb2 100644 --- a/tests/push/test_bulk_push_rule_evaluator.py +++ b/tests/push/test_bulk_push_rule_evaluator.py @@ -26,7 +26,7 @@ from twisted.internet.testing import MemoryReactor -from synapse.api.constants import EventContentFields, RelationTypes +from synapse.api.constants import EventContentFields, EventTypes, RelationTypes from synapse.api.room_versions import RoomVersions from synapse.push.bulk_push_rule_evaluator import BulkPushRuleEvaluator from synapse.rest import admin @@ -206,7 +206,10 @@ def test_action_for_event_by_user_disabled_by_config(self) -> None: bulk_evaluator._action_for_event_by_user.assert_not_called() def _create_and_process( - self, bulk_evaluator: BulkPushRuleEvaluator, content: Optional[JsonDict] = None + self, + bulk_evaluator: BulkPushRuleEvaluator, + content: Optional[JsonDict] = None, + type: str = "test", ) -> bool: """Returns true iff the `mentions` trigger an event push action.""" # Create a new message event which should cause a notification. @@ -214,7 +217,7 @@ def _create_and_process( self.event_creation_handler.create_event( self.requester, { - "type": "test", + "type": type, "room_id": self.room_id, "content": content or {}, "sender": f"@bob:{self.hs.hostname}", @@ -446,3 +449,73 @@ def test_suppress_edits(self) -> None: }, ) ) + + @override_config({"experimental_features": {"msc4306_enabled": True}}) + def test_thread_subscriptions(self) -> None: + bulk_evaluator = BulkPushRuleEvaluator(self.hs) + (thread_root_id,) = self.helper.send_messages(self.room_id, 1, tok=self.token) + + self.assertFalse( + self._create_and_process( + bulk_evaluator, + { + "msgtype": "m.text", + "body": "test message before subscription", + "m.relates_to": { + "rel_type": RelationTypes.THREAD, + "event_id": thread_root_id, + }, + }, + type=EventTypes.Message, + ) + ) + + self.get_success( + self.hs.get_datastores().main.subscribe_user_to_thread( + self.alice, + self.room_id, + thread_root_id, + automatic_event_orderings=None, + ) + ) + + self.assertTrue( + self._create_and_process( + bulk_evaluator, + { + "msgtype": "m.text", + "body": "test message after subscription", + "m.relates_to": { + "rel_type": RelationTypes.THREAD, + "event_id": thread_root_id, + }, + }, + type="m.room.message", + ) + ) + + def test_with_disabled_thread_subscriptions(self) -> None: + """ + Test what happens with threaded events when MSC4306 is disabled. + + FUTURE: If MSC4306 becomes enabled-by-default/accepted, this test is to be removed. + """ + bulk_evaluator = BulkPushRuleEvaluator(self.hs) + (thread_root_id,) = self.helper.send_messages(self.room_id, 1, tok=self.token) + + # When MSC4306 is not enabled, a threaded message generates a notification + # by default. + self.assertTrue( + self._create_and_process( + bulk_evaluator, + { + "msgtype": "m.text", + "body": "test message before subscription", + "m.relates_to": { + "rel_type": RelationTypes.THREAD, + "event_id": thread_root_id, + }, + }, + type="m.room.message", + ) + ) diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py index 8789c2f4cf1..3a351acffa5 100644 --- a/tests/push/test_push_rule_evaluator.py +++ b/tests/push/test_push_rule_evaluator.py @@ -150,6 +150,7 @@ def _get_evaluator( *, related_events: Optional[JsonDict] = None, msc4210: bool = False, + msc4306: bool = False, ) -> PushRuleEvaluator: event = FrozenEvent( { @@ -176,6 +177,7 @@ def _get_evaluator( room_version_feature_flags=event.room_version.msc3931_push_features, msc3931_enabled=True, msc4210_enabled=msc4210, + msc4306_enabled=msc4306, ) def test_display_name(self) -> None: @@ -806,6 +808,112 @@ def test_related_event_match_no_related_event(self) -> None: ) ) + def test_thread_subscription_subscribed(self) -> None: + """ + Test MSC4306 thread subscription push rules against an event in a subscribed thread. + """ + evaluator = self._get_evaluator( + { + "msgtype": "m.text", + "body": "Squawk", + "m.relates_to": { + "event_id": "$threadroot", + "rel_type": "m.thread", + }, + }, + msc4306=True, + ) + self.assertTrue( + evaluator.matches( + { + "kind": "io.element.msc4306.thread_subscription", + "subscribed": True, + }, + None, + None, + msc4306_thread_subscription_state=True, + ) + ) + self.assertFalse( + evaluator.matches( + { + "kind": "io.element.msc4306.thread_subscription", + "subscribed": False, + }, + None, + None, + msc4306_thread_subscription_state=True, + ) + ) + + def test_thread_subscription_unsubscribed(self) -> None: + """ + Test MSC4306 thread subscription push rules against an event in an unsubscribed thread. + """ + evaluator = self._get_evaluator( + { + "msgtype": "m.text", + "body": "Squawk", + "m.relates_to": { + "event_id": "$threadroot", + "rel_type": "m.thread", + }, + }, + msc4306=True, + ) + self.assertFalse( + evaluator.matches( + { + "kind": "io.element.msc4306.thread_subscription", + "subscribed": True, + }, + None, + None, + msc4306_thread_subscription_state=False, + ) + ) + self.assertTrue( + evaluator.matches( + { + "kind": "io.element.msc4306.thread_subscription", + "subscribed": False, + }, + None, + None, + msc4306_thread_subscription_state=False, + ) + ) + + def test_thread_subscription_unthreaded(self) -> None: + """ + Test MSC4306 thread subscription push rules against an unthreaded event. + """ + evaluator = self._get_evaluator( + {"msgtype": "m.text", "body": "Squawk"}, msc4306=True + ) + self.assertFalse( + evaluator.matches( + { + "kind": "io.element.msc4306.thread_subscription", + "subscribed": True, + }, + None, + None, + msc4306_thread_subscription_state=None, + ) + ) + self.assertFalse( + evaluator.matches( + { + "kind": "io.element.msc4306.thread_subscription", + "subscribed": False, + }, + None, + None, + msc4306_thread_subscription_state=None, + ) + ) + class TestBulkPushRuleEvaluator(unittest.HomeserverTestCase): """Tests for the bulk push rule evaluator""" diff --git a/tests/storage/test_thread_subscriptions.py b/tests/storage/test_thread_subscriptions.py index c09a4a9a441..2a5c440cf49 100644 --- a/tests/storage/test_thread_subscriptions.py +++ b/tests/storage/test_thread_subscriptions.py @@ -327,3 +327,42 @@ def test_should_skip_autosubscription_after_unsubscription(self) -> None: self.assertFalse( func(autosub=EventOrderings(-50, 2), unsubscribed_at=EventOrderings(2, 1)) ) + + def test_get_subscribers_to_thread(self) -> None: + """ + Test getting all subscribers to a thread at once. + + To check cache invalidations are correct, we do multiple + step-by-step rounds of subscription changes and assertions. + """ + other_user_id = "@other_user:test" + + subscribers = self.get_success( + self.store.get_subscribers_to_thread(self.room_id, self.thread_root_id) + ) + self.assertEqual(subscribers, frozenset()) + + self._subscribe( + self.thread_root_id, automatic_event_orderings=None, user_id=self.user_id + ) + + subscribers = self.get_success( + self.store.get_subscribers_to_thread(self.room_id, self.thread_root_id) + ) + self.assertEqual(subscribers, frozenset((self.user_id,))) + + self._subscribe( + self.thread_root_id, automatic_event_orderings=None, user_id=other_user_id + ) + + subscribers = self.get_success( + self.store.get_subscribers_to_thread(self.room_id, self.thread_root_id) + ) + self.assertEqual(subscribers, frozenset((self.user_id, other_user_id))) + + self._unsubscribe(self.thread_root_id, user_id=self.user_id) + + subscribers = self.get_success( + self.store.get_subscribers_to_thread(self.room_id, self.thread_root_id) + ) + self.assertEqual(subscribers, frozenset((other_user_id,)))