Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/11804.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Implement [MSC3664](https://github.com/matrix-org/matrix-doc/pull/3664). Contributed by Nico.
16 changes: 16 additions & 0 deletions rust/src/push/base_rules.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ use crate::push::Action;
use crate::push::Condition;
use crate::push::EventMatchCondition;
use crate::push::PushRule;
use crate::push::RelatedEventMatchCondition;
use crate::push::SetTweak;
use crate::push::TweakValue;

Expand Down Expand Up @@ -114,6 +115,21 @@ pub const BASE_APPEND_OVERRIDE_RULES: &[PushRule] = &[
default: true,
default_enabled: true,
},
PushRule {
rule_id: Cow::Borrowed("global/override/.im.nheko.msc3664.reply"),
priority_class: 5,
conditions: Cow::Borrowed(&[Condition::Known(KnownCondition::RelatedEventMatch(
RelatedEventMatchCondition {
key: Some(Cow::Borrowed("sender")),
pattern: None,
pattern_type: Some(Cow::Borrowed("user_id")),
rel_type: Cow::Borrowed("m.in_reply_to"),
},
))]),
actions: Cow::Borrowed(&[Action::Notify, HIGHLIGHT_ACTION, SOUND_ACTION]),
default: true,
default_enabled: true,
},
PushRule {
rule_id: Cow::Borrowed("global/override/.m.rule.contains_display_name"),
priority_class: 5,
Expand Down
91 changes: 89 additions & 2 deletions rust/src/push/evaluator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ use regex::Regex;
use super::{
utils::{get_glob_matcher, get_localpart_from_id, GlobMatchType},
Action, Condition, EventMatchCondition, FilteredPushRules, KnownCondition,
RelatedEventMatchCondition,
};

lazy_static! {
Expand All @@ -49,6 +50,13 @@ pub struct PushRuleEvaluator {
/// The power level of the sender of the event, or None if event is an
/// outlier.
sender_power_level: Option<i64>,

/// The related events, indexed by relation type. Flattened in the same manner as
/// `flattened_keys`.
related_events_flattened: BTreeMap<String, BTreeMap<String, String>>,

/// If msc3664, push rules for related events, is enabled.
related_event_match_enabled: bool,
}

#[pymethods]
Expand All @@ -60,6 +68,8 @@ impl PushRuleEvaluator {
room_member_count: u64,
sender_power_level: Option<i64>,
notification_power_levels: BTreeMap<String, i64>,
related_events_flattened: BTreeMap<String, BTreeMap<String, String>>,
related_event_match_enabled: bool,
) -> Result<Self, Error> {
let body = flattened_keys
.get("content.body")
Expand All @@ -72,6 +82,8 @@ impl PushRuleEvaluator {
room_member_count,
notification_power_levels,
sender_power_level,
related_events_flattened,
related_event_match_enabled,
})
}

Expand Down Expand Up @@ -156,6 +168,9 @@ impl PushRuleEvaluator {
KnownCondition::EventMatch(event_match) => {
self.match_event_match(event_match, user_id)?
}
KnownCondition::RelatedEventMatch(event_match) => {
self.match_related_event_match(event_match, user_id)?
}
KnownCondition::ContainsDisplayName => {
if let Some(dn) = display_name {
if !dn.is_empty() {
Expand Down Expand Up @@ -239,6 +254,71 @@ impl PushRuleEvaluator {
compiled_pattern.is_match(haystack)
}

/// Evaluates a `related_event_match` condition. (MSC3664)
fn match_related_event_match(
&self,
event_match: &RelatedEventMatchCondition,
user_id: Option<&str>,
) -> Result<bool, Error> {
// First check if related event matching is enabled...
if !self.related_event_match_enabled {
return Ok(false);
}

// get the related event, fail if there is none.
let event = if let Some(event) = self.related_events_flattened.get(&*event_match.rel_type) {
event
} else {
return Ok(false);
};

// if we have no key, accept the event as matching, if it existed without matching any
// fields.
let key = if let Some(key) = &event_match.key {
key
} else {
return Ok(true);
};

let pattern = if let Some(pattern) = &event_match.pattern {
pattern
} else if let Some(pattern_type) = &event_match.pattern_type {
// The `pattern_type` can either be "user_id" or "user_localpart",
// either way if we don't have a `user_id` then the condition can't
// match.
let user_id = if let Some(user_id) = user_id {
user_id
} else {
return Ok(false);
};

match &**pattern_type {
"user_id" => user_id,
"user_localpart" => get_localpart_from_id(user_id)?,
_ => return Ok(false),
}
} else {
return Ok(false);
};

let haystack = if let Some(haystack) = event.get(&**key) {
haystack
} else {
return Ok(false);
};

// For the content.body we match against "words", but for everything
// else we match against the entire value.
let match_type = if key == "content.body" {
GlobMatchType::Word
} else {
GlobMatchType::Whole
};

let mut compiled_pattern = get_glob_matcher(pattern, match_type)?;
compiled_pattern.is_match(haystack)
}

/// Match the member count against an 'is' condition
/// The `is` condition can be things like '>2', '==3' or even just '4'.
fn match_member_count(&self, is: &str) -> Result<bool, Error> {
Expand Down Expand Up @@ -267,8 +347,15 @@ impl PushRuleEvaluator {
fn push_rule_evaluator() {
let mut flattened_keys = BTreeMap::new();
flattened_keys.insert("content.body".to_string(), "foo bar bob hello".to_string());
let evaluator =
PushRuleEvaluator::py_new(flattened_keys, 10, Some(0), BTreeMap::new()).unwrap();
let evaluator = PushRuleEvaluator::py_new(
flattened_keys,
10,
Some(0),
BTreeMap::new(),
BTreeMap::new(),
true,
)
.unwrap();

let result = evaluator.run(&FilteredPushRules::default(), None, Some("bob"));
assert_eq!(result.len(), 3);
Expand Down
59 changes: 51 additions & 8 deletions rust/src/push/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,8 @@ pub enum Condition {
#[serde(tag = "kind")]
pub enum KnownCondition {
EventMatch(EventMatchCondition),
#[serde(rename = "im.nheko.msc3664.related_event_match")]
RelatedEventMatch(RelatedEventMatchCondition),
ContainsDisplayName,
RoomMemberCount {
#[serde(skip_serializing_if = "Option::is_none")]
Expand Down Expand Up @@ -299,6 +301,18 @@ pub struct EventMatchCondition {
pub pattern_type: Option<Cow<'static, str>>,
}

/// The body of a [`Condition::RelatedEventMatch`]
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct RelatedEventMatchCondition {
#[serde(skip_serializing_if = "Option::is_none")]
pub key: Option<Cow<'static, str>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub pattern: Option<Cow<'static, str>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub pattern_type: Option<Cow<'static, str>>,
pub rel_type: Cow<'static, str>,
}

/// The collection of push rules for a user.
#[derive(Debug, Clone, Default)]
#[pyclass(frozen)]
Expand Down Expand Up @@ -391,15 +405,21 @@ impl PushRules {
pub struct FilteredPushRules {
push_rules: PushRules,
enabled_map: BTreeMap<String, bool>,
msc3664_enabled: bool,
}

#[pymethods]
impl FilteredPushRules {
#[new]
pub fn py_new(push_rules: PushRules, enabled_map: BTreeMap<String, bool>) -> Self {
pub fn py_new(
push_rules: PushRules,
enabled_map: BTreeMap<String, bool>,
msc3664_enabled: bool,
) -> Self {
Self {
push_rules,
enabled_map,
msc3664_enabled,
}
}

Expand All @@ -414,13 +434,25 @@ impl FilteredPushRules {
/// Iterates over all the rules and their enabled state, including base
/// rules, in the order they should be executed in.
fn iter(&self) -> impl Iterator<Item = (&PushRule, bool)> {
self.push_rules.iter().map(|r| {
let enabled = *self
.enabled_map
.get(&*r.rule_id)
.unwrap_or(&r.default_enabled);
(r, enabled)
})
self.push_rules
.iter()
.filter(|rule| {
// Ignore disabled experimental push rules
if !self.msc3664_enabled
&& rule.rule_id == "global/override/.im.nheko.msc3664.reply"
{
return false;
}

true
})
.map(|r| {
let enabled = *self
.enabled_map
.get(&*r.rule_id)
.unwrap_or(&r.default_enabled);
(r, enabled)
})
}
}

Expand All @@ -446,6 +478,17 @@ fn test_deserialize_condition() {
let _: Condition = serde_json::from_str(json).unwrap();
}

#[test]
fn test_deserialize_unstable_condition() {
let json = r#"{"kind":"im.nheko.msc3664.related_event_match","key":"content.body","pattern":"coffee","rel_type":"m.in_reply_to"}"#;

let condition: Condition = serde_json::from_str(json).unwrap();
assert!(matches!(
condition,
Condition::Known(KnownCondition::RelatedEventMatch(_))
));
}

#[test]
fn test_deserialize_custom_condition() {
let json = r#"{"kind":"custom_tag"}"#;
Expand Down
6 changes: 5 additions & 1 deletion stubs/synapse/synapse_rust/push.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@ class PushRules:
def rules(self) -> Collection[PushRule]: ...

class FilteredPushRules:
def __init__(self, push_rules: PushRules, enabled_map: Dict[str, bool]): ...
def __init__(
self, push_rules: PushRules, enabled_map: Dict[str, bool], msc3664_enabled: bool
): ...
def rules(self) -> Collection[Tuple[PushRule, bool]]: ...

def get_base_rule_ids() -> Collection[str]: ...
Expand All @@ -37,6 +39,8 @@ class PushRuleEvaluator:
room_member_count: int,
sender_power_level: Optional[int],
notification_power_levels: Mapping[str, int],
related_events_flattened: Mapping[str, Mapping[str, str]],
related_event_match_enabled: bool,
): ...
def run(
self,
Expand Down
3 changes: 3 additions & 0 deletions synapse/config/experimental.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,9 @@ def read_config(self, config: JsonDict, **kwargs: Any) -> None:
# MSC3773: Thread notifications
self.msc3773_enabled: bool = experimental.get("msc3773_enabled", False)

# MSC3664: Pushrules to match on related events
self.msc3664_enabled: bool = experimental.get("msc3664_enabled", False)

# MSC3848: Introduce errcodes for specific event sending failures
self.msc3848_enabled: bool = experimental.get("msc3848_enabled", False)

Expand Down
33 changes: 32 additions & 1 deletion synapse/push/bulk_push_rule_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@

logger = logging.getLogger(__name__)


push_rules_invalidation_counter = Counter(
"synapse_push_bulk_push_rule_evaluator_push_rules_invalidation_counter", ""
)
Expand Down Expand Up @@ -107,6 +106,8 @@ def __init__(self, hs: "HomeServer"):
self.clock = hs.get_clock()
self._event_auth_handler = hs.get_event_auth_handler()

self._related_event_match_enabled = self.hs.config.experimental.msc3664_enabled

self.room_push_rule_cache_metrics = register_cache(
"cache",
"room_push_rule_cache",
Expand Down Expand Up @@ -225,6 +226,34 @@ async def action_for_event_by_user(
sender_power_level,
) = await self._get_power_levels_and_sender_level(event, context)

related_events: Dict[str, Dict[str, str]] = {}
if self._related_event_match_enabled:
related_event_id = event.content.get("m.relates_to", {}).get("event_id")
relation_type = event.content.get("m.relates_to", {}).get("rel_type")
if related_event_id is not None and relation_type is not None:
related_event = await self.store.get_event(
related_event_id, allow_none=True
)
if related_event is not None:
related_events[relation_type] = _flatten_dict(related_event)

reply_event_id = (
event.content.get("m.relates_to", {})
.get("m.in_reply_to", {})
.get("event_id")
)
if reply_event_id is not None and (
relation_type != "m.thread"
or not event.content.get("m.relates_to", {}).get(
"is_falling_back", False
)
):
related_event = await self.store.get_event(
reply_event_id, allow_none=True
)
if related_event is not None:
related_events["m.in_reply_to"] = _flatten_dict(related_event)

# Find the event's thread ID.
relation = relation_from_event(event)
# If the event does not have a relation, then it cannot have a thread ID.
Expand All @@ -250,6 +279,8 @@ async def action_for_event_by_user(
room_member_count,
sender_power_level,
notification_levels,
related_events,
self._related_event_match_enabled,
)

users = rules_by_user.keys()
Expand Down
5 changes: 5 additions & 0 deletions synapse/rest/client/capabilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,11 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
"enabled": True,
}

if self.config.experimental.msc3664_enabled:
response["capabilities"]["im.nheko.msc3664.related_event_match"] = {
"enabled": self.config.experimental.msc3664_enabled,
}

return HTTPStatus.OK, response


Expand Down
Loading