Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/11804.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Implement [MSC3664](https://github.com/matrix-org/matrix-doc/pull/3664). Contributed by Nico.
38 changes: 38 additions & 0 deletions synapse/push/baserules.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,25 @@ def make_base_prepend_rules(
],
"actions": ["dont_notify"],
},
# Enable notifications for replies without fallback
{
"rule_id": "global/override/.im.nheko.msc3664.reply",
"conditions": [
# Only send notification if the reply is to your message
{
"kind": "im.nheko.msc3664.related_event_match",
"key": "sender",
"rel_type": "m.in_reply_to",
"pattern_type": "user_id",
"_id": "_reply",
},
],
"actions": [
"notify",
{"set_tweak": "sound", "value": "default"},
{"set_tweak": "highlight"},
],
},
# This was changed from underride to override so it's closer in priority
# to the content rules where the user name highlight rule lives. This
# way a room rule is lower priority than both but a custom override rule
Expand Down Expand Up @@ -363,6 +382,25 @@ def make_base_prepend_rules(
],
"actions": ["notify", {"set_tweak": "sound", "value": "default"}],
},
# Enable notifications for replies without fallback
{
"rule_id": "global/override/.im.nheko.msc3664.reply",
"conditions": [
# Only send notification if the reply is to your message
{
"kind": "im.nheko.msc3664.related_event_match",
"key": "sender",
"rel_type": "m.in_reply_to",
"pattern_type": "user_id",
"_id": "_reply",
},
],
"actions": [
"notify",
{"set_tweak": "sound", "value": "default"},
{"set_tweak": "highlight"},
],
},
{
"rule_id": "global/override/.m.rule.contains_display_name",
"conditions": [{"kind": "contains_display_name"}],
Expand Down
23 changes: 21 additions & 2 deletions synapse/push/bulk_push_rule_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@

logger = logging.getLogger(__name__)


push_rules_invalidation_counter = Counter(
"synapse_push_bulk_push_rule_evaluator_push_rules_invalidation_counter", ""
)
Expand Down Expand Up @@ -203,8 +202,28 @@ async def action_for_event_by_user(
sender_power_level,
) = await self._get_power_levels_and_sender_level(event, context)

related_events: Dict[str, EventBase] = {}
related_event_id = event.content.get("m.relates_to", {}).get("event_id")
relation_type = event.content.get("m.relates_to", {}).get("rel_type")
if related_event_id is not None and relation_type is not None:
related_event = await self.store.get_event(
related_event_id, allow_none=True
)
if related_event is not None:
related_events[relation_type] = related_event

reply_event_id = (
event.content.get("m.relates_to", {})
.get("m.in_reply_to", {})
.get("event_id")
)
if reply_event_id is not None:
related_event = await self.store.get_event(reply_event_id, allow_none=True)
if related_event is not None:
related_events["m.in_reply_to"] = related_event

evaluator = PushRuleEvaluatorForEvent(
event, len(room_members), sender_power_level, power_levels
event, len(room_members), sender_power_level, power_levels, related_events
)

condition_cache: Dict[str, bool] = {}
Expand Down
42 changes: 35 additions & 7 deletions synapse/push/push_rule_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ def __init__(
room_member_count: int,
sender_power_level: int,
power_levels: Dict[str, Union[int, Dict[str, int]]],
related_events: Dict[str, EventBase],
):
self._event = event
self._room_member_count = room_member_count
Expand All @@ -129,11 +130,35 @@ def __init__(
# Maps strings of e.g. 'content.body' -> event["content"]["body"]
self._value_cache = _flatten_dict(event)

self._related_events = related_events
self._related_events_value_cache = {
k: _flatten_dict(v) for k, v in related_events.items()
}

def matches(
self, condition: Dict[str, Any], user_id: str, display_name: Optional[str]
) -> bool:
if condition["kind"] == "event_match":
return self._event_match(condition, user_id)
return self._event_match(condition, user_id, self._event, self._value_cache)
elif condition["kind"] == "im.nheko.msc3664.related_event_match":
# If we have no related event, the pattern will never match
if not self._related_events:
return False

related_event = self._related_events.get(condition["rel_type"])
related_event_value_cache = self._related_events_value_cache.get(
condition["rel_type"]
)
if not related_event or not related_event_value_cache:
return False

# we have a related event, but we only want to match for existence
if not condition.get("key"):
return True

return self._event_match(
condition, user_id, related_event, related_event_value_cache
)
elif condition["kind"] == "contains_display_name":
return self._contains_display_name(display_name)
elif condition["kind"] == "room_member_count":
Expand All @@ -145,7 +170,13 @@ def matches(
else:
return True

def _event_match(self, condition: dict, user_id: str) -> bool:
def _event_match(
self,
condition: dict,
user_id: str,
event: EventBase,
event_value_cache: Dict[str, str],
) -> bool:
pattern = condition.get("pattern", None)

if not pattern:
Expand All @@ -161,13 +192,13 @@ def _event_match(self, condition: dict, user_id: str) -> bool:

# XXX: optimisation: cache our pattern regexps
if condition["key"] == "content.body":
body = self._event.content.get("body", None)
body = event.content.get("body", None)
if not body or not isinstance(body, str):
return False

return _glob_matches(pattern, body, word_boundary=True)
else:
haystack = self._get_value(condition["key"])
haystack = event_value_cache.get(condition["key"], None)
if haystack is None:
return False

Expand All @@ -191,9 +222,6 @@ def _contains_display_name(self, display_name: Optional[str]) -> bool:

return bool(r.search(body))

def _get_value(self, dotted_key: str) -> Optional[str]:
return self._value_cache.get(dotted_key, None)


# Caches (string, is_glob, word_boundary) -> regex for push. See _glob_matches
regex_cache: LruCache[Tuple[str, bool, bool], Pattern] = LruCache(
Expand Down
149 changes: 147 additions & 2 deletions tests/push/test_push_rule_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@


class PushRuleEvaluatorTestCase(unittest.TestCase):
def _get_evaluator(self, content):
def _get_evaluator(self, content, related_events=None):
event = FrozenEvent(
{
"event_id": "$event_id",
Expand All @@ -39,7 +39,11 @@ def _get_evaluator(self, content):
sender_power_level = 0
power_levels = {}
return PushRuleEvaluatorForEvent(
event, room_member_count, sender_power_level, power_levels
event,
room_member_count,
sender_power_level,
power_levels,
{} if related_events is None else related_events,
)

def test_display_name(self):
Expand Down Expand Up @@ -266,3 +270,144 @@ def test_tweaks_for_actions(self):
push_rule_evaluator.tweaks_for_actions(actions),
{"sound": "default", "highlight": True},
)

def test_related_event_match(self):
evaluator = self._get_evaluator(
{
"m.relates_to": {
"event_id": "$parent_event_id",
"key": "\ud83d\udc4d\ufe0f",
"rel_type": "m.annotation",
"m.in_reply_to": {
"event_id": "$parent_event_id",
},
}
},
{
"m.in_reply_to": FrozenEvent(
{
"event_id": "$parent_event_id",
"type": "m.room.message",
"sender": "@other_user:test",
"room_id": "#room:test",
"content": {"msgtype": "m.text", "body": "Original message"},
},
RoomVersions.V1,
),
"m.annotation": FrozenEvent(
{
"event_id": "$parent_event_id",
"type": "m.room.message",
"sender": "@other_user:test",
"room_id": "#room:test",
"content": {"msgtype": "m.text", "body": "Original message"},
},
RoomVersions.V1,
),
},
)
self.assertFalse(
evaluator.matches(
{
"kind": "im.nheko.msc3664.related_event_match",
"key": "sender",
"rel_type": "m.in_reply_to",
"pattern_type": "user_id",
},
"@user:test",
"display_name",
)
)
self.assertTrue(
evaluator.matches(
{
"kind": "im.nheko.msc3664.related_event_match",
"key": "sender",
"rel_type": "m.in_reply_to",
"pattern_type": "user_id",
},
"@other_user:test",
"display_name",
)
)
self.assertTrue(
evaluator.matches(
{
"kind": "im.nheko.msc3664.related_event_match",
"key": "sender",
"rel_type": "m.annotation",
"pattern_type": "user_id",
},
"@other_user:test",
"display_name",
)
)
self.assertFalse(
evaluator.matches(
{
"kind": "im.nheko.msc3664.related_event_match",
"key": "sender",
"rel_type": "m.in_reply_to",
},
"@user:test",
"display_name",
)
)
self.assertTrue(
evaluator.matches(
{
"kind": "im.nheko.msc3664.related_event_match",
"rel_type": "m.in_reply_to",
},
"@user:test",
"display_name",
)
)
self.assertFalse(
evaluator.matches(
{
"kind": "im.nheko.msc3664.related_event_match",
"rel_type": "m.replace",
},
"@other_user:test",
"display_name",
)
)

def test_related_event_match_no_related_event(self):
evaluator = self._get_evaluator(
{"msgtype": "m.text", "body": "Message without related event"}
)
self.assertFalse(
evaluator.matches(
{
"kind": "im.nheko.msc3664.related_event_match",
"key": "sender",
"rel_type": "m.in_reply_to",
"pattern_type": "user_id",
},
"@user:test",
"display_name",
)
)
self.assertFalse(
evaluator.matches(
{
"kind": "im.nheko.msc3664.related_event_match",
"key": "sender",
"rel_type": "m.in_reply_to",
},
"@user:test",
"display_name",
)
)
self.assertFalse(
evaluator.matches(
{
"kind": "im.nheko.msc3664.related_event_match",
"rel_type": "m.in_reply_to",
},
"@user:test",
"display_name",
)
)