Skip to content
33 changes: 17 additions & 16 deletions src/api/client/room/create.rs
Original file line number Diff line number Diff line change
Expand Up @@ -244,22 +244,23 @@ pub(crate) async fn create_room_route(
.await?;

// 5.3 Guest Access
services
.timeline
.build_and_append_pdu(
PduBuilder::state(
String::new(),
&RoomGuestAccessEventContent::new(match preset {
| RoomPreset::PublicChat => GuestAccess::Forbidden,
| _ => GuestAccess::CanJoin,
}),
),
sender_user,
&room_id,
&state_lock,
)
.boxed()
.await?;
// Only send for non-public presets (matching Synapse's behavior where
// guest_can_join is true for private_chat and trusted_private_chat only)
if preset != RoomPreset::PublicChat {
services
.timeline
.build_and_append_pdu(
PduBuilder::state(
String::new(),
&RoomGuestAccessEventContent::new(GuestAccess::CanJoin),
),
sender_user,
&room_id,
&state_lock,
)
.boxed()
.await?;
}

// 6. Events listed in initial_state
let mut is_encrypted = false;
Expand Down
217 changes: 197 additions & 20 deletions src/api/server/send_join.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#![expect(deprecated)]

use std::borrow::Borrow;
use std::{borrow::Borrow, collections::HashSet};

use axum::extract::State;
use futures::{FutureExt, StreamExt, TryFutureExt, TryStreamExt, future::try_join4};
Expand All @@ -14,12 +14,13 @@ use ruma::{
};
use serde_json::value::RawValue as RawJsonValue;
use tuwunel_core::{
Err, Result, at, err,
matrix::event::gen_event_id_canonical_json,
utils::stream::{IterStream, TryBroadbandExt},
Err, Result, at, debug, err,
itertools::Itertools,
matrix::{Event, event::gen_event_id_canonical_json},
utils::stream::{BroadbandExt, IterStream, TryBroadbandExt},
warn,
};
use tuwunel_service::Services;
use tuwunel_service::{Services, rooms::short::ShortStateKey};

use crate::Ruma;

Expand All @@ -29,6 +30,7 @@ async fn create_join_event(
origin: &ServerName,
room_id: &RoomId,
pdu: &RawJsonValue,
omit_members: bool,
) -> Result<create_join_event::v1::RoomState> {
if !services.metadata.exists(room_id).await {
return Err!(Request(NotFound("Room is unknown to this server.")));
Expand Down Expand Up @@ -188,9 +190,150 @@ async fn create_join_event(
let state_ids = services
.state_accessor
.state_full_ids(shortstatehash)
.map(at!(1))
.collect::<Vec<OwnedEventId>>()
.boxed();
.collect::<Vec<_>>()
.boxed()
.shared();
// Filter out members if omit_members is true (MSC3706 + MSC3943)
let filtered_state_ids = if omit_members {
let joining_user_ssk = services
.short
.get_shortstatekey(&StateEventType::RoomMember, state_key.as_str())
.await
.ok();

state_ids
.clone()
.then(move |state| {
let joining_user_ssk = joining_user_ssk;
async move {
// MSC3943: Only include heroes when the room has no name and no
// canonical alias (matching Synapse's behavior in PR #14442).
let has_name = state
.iter()
.stream()
.any(|&(ssk, _)| async move {
services
.short
.get_statekey_from_short(ssk)
.await
.is_ok_and(|(et, sk)| {
et == StateEventType::RoomName && sk.is_empty()
})
})
.await;

let has_alias = state
.iter()
.stream()
.any(|&(ssk, _)| async move {
services
.short
.get_statekey_from_short(ssk)
.await
.is_ok_and(|(et, sk)| {
et == StateEventType::RoomCanonicalAlias && sk.is_empty()
})
})
.await;

// Collect hero SSKs only if room has no name and no canonical alias
let heroes_ssks: HashSet<ShortStateKey> = if !has_name && !has_alias {
// Classify members by membership state, excluding the joining
// user (matching Synapse's extract_heroes_from_room_summary).
let mut joined_invited: Vec<(ShortStateKey, String)> = Vec::new();
let mut left_banned: Vec<(ShortStateKey, String)> = Vec::new();

for &(ssk, ref eid) in &state {
let Ok((et, key)) = services.short.get_statekey_from_short(ssk).await
else {
continue;
};
if et != StateEventType::RoomMember {
continue;
}
// Exclude the joining user from heroes
if Some(ssk) == joining_user_ssk {
continue;
}
let Ok(pdu) = services.timeline.get_pdu(eid).await else {
continue;
};
let Ok(content) = serde_json::from_str::<RoomMemberEventContent>(
pdu.content().get(),
) else {
continue;
};
match content.membership {
| MembershipState::Join | MembershipState::Invite => {
joined_invited.push((ssk, key.to_string()));
},
| MembershipState::Leave | MembershipState::Ban => {
left_banned.push((ssk, key.to_string()));
},
| _ => {},
}
}

// Synapse: use joined+invited if any, otherwise fall back to
// left+banned. Sort by MXID, take first 5.
let heroes = if !joined_invited.is_empty() {
joined_invited
} else {
left_banned
};

heroes
.into_iter()
.sorted_by_key(|(_, key)| key.clone())
.map(|(ssk, _)| ssk)
.take(5)
.collect()
} else {
HashSet::new()
};

// Filter state: keep all non-member events, the joining user's
// member event, and hero member events. If get_statekey_from_short
// fails, keep the event (safe default, matching original behavior).
state
.iter()
.stream()
.broad_filter_map(move |&(ssk, ref eid)| {
let joining_user_ssk = joining_user_ssk;
let heroes_ssks = heroes_ssks.clone();
let eid = eid.clone();
async move {
let keep = services
.short
.get_statekey_from_short(ssk)
.await
.map(|(et, _)| {
et != StateEventType::RoomMember
|| Some(ssk) == joining_user_ssk || heroes_ssks
.contains(&ssk)
})
.unwrap_or(true); // safe default: keep unknown events

keep.then_some(eid)
}
})
.collect::<Vec<OwnedEventId>>()
.await
}
})
.boxed()
} else {
state_ids
.clone()
.map(|state| {
state
.iter()
.map(|(_, eid)| eid)
.cloned()
.collect::<Vec<_>>()
})
.boxed()
};

let mutex_lock = services
.event_handler
Expand Down Expand Up @@ -219,27 +362,47 @@ async fn create_join_event(
let broadcast = services.sending.send_pdu_room(room_id, &pdu_id);

// Wait for state gather which the remaining operations depend on.
let state_ids = state_ids.await;
let state_ids = filtered_state_ids.await;

let auth_heads = state_ids.iter().map(Borrow::borrow);

let into_federation_format = |pdu| {
services
.federation
.format_pdu_into(pdu, Some(&room_version))
.map(Ok)
};

let auth_chain = services
let auth_chain_ids: HashSet<OwnedEventId> = services
.auth_chain
.event_ids_iter(room_id, &room_version, auth_heads)
.try_collect()
.await?;

let state_ids_set: HashSet<OwnedEventId> = state_ids.iter().cloned().collect();

let auth_chain = auth_chain_ids
.into_iter()
.stream()
.map(Ok::<_, tuwunel_core::Error>)
.broad_and_then(async |event_id| {
services
.timeline
.get_pdu_json(&event_id)
.and_then(into_federation_format)
.await
// MSC3706: Any events returned within state can be omitted from auth_chain.
if omit_members && state_ids_set.contains(&event_id) {
return Ok(None);
}

let json = services.timeline.get_pdu_json(&event_id).await;

match json {
| Ok(pdu) => into_federation_format(pdu).await.map(Some),
| Err(e) => {
debug!(?event_id, "auth chain event not found: {e}");
Ok(None)
},
}
})
.try_filter_map(|opt_event| futures::future::ready(Ok(opt_event)))
.try_collect();

let state = state_ids
.iter()
.try_stream()
Expand Down Expand Up @@ -284,7 +447,7 @@ pub(crate) async fn create_join_event_v1_route(
}

Ok(create_join_event::v1::Response {
room_state: create_join_event(&services, body.origin(), &body.room_id, &body.pdu)
room_state: create_join_event(&services, body.origin(), &body.room_id, &body.pdu, false)
.boxed()
.await?,
})
Expand Down Expand Up @@ -315,18 +478,32 @@ pub(crate) async fn create_join_event_v2_route(
))));
}

// Get the servers in the room BEFORE the join
let servers_in_room = if body.omit_members {
Some(
services
.state_cache
.room_servers(&body.room_id)
.map(ToString::to_string)
.collect::<Vec<_>>()
.await,
)
} else {
None
};

let create_join_event::v1::RoomState { auth_chain, state, event } =
create_join_event(&services, body.origin(), &body.room_id, &body.pdu)
create_join_event(&services, body.origin(), &body.room_id, &body.pdu, body.omit_members)
.boxed()
.await?;

Ok(create_join_event::v2::Response {
room_state: create_join_event::v2::RoomState {
members_omitted: false,
members_omitted: body.omit_members,
auth_chain,
state,
event,
servers_in_room: None,
servers_in_room,
},
})
}
2 changes: 1 addition & 1 deletion tests/complement/results.jsonl
Original file line number Diff line number Diff line change
Expand Up @@ -699,7 +699,7 @@
{"Action":"pass","Test":"TestSearch/parallel/Search_results_with_recent_ordering_do_not_include_redacted_events"}
{"Action":"pass","Test":"TestSearch/parallel/Search_works_across_an_upgraded_room_and_its_predecessor"}
{"Action":"fail","Test":"TestSendAndFetchMessage"}
{"Action":"skip","Test":"TestSendJoinPartialStateResponse"}
{"Action":"pass","Test":"TestSendJoinPartialStateResponse"}
{"Action":"pass","Test":"TestSendMessageWithTxn"}
{"Action":"pass","Test":"TestServerCapabilities"}
{"Action":"skip","Test":"TestServerNotices"}
Expand Down
Loading