Skip to content

Commit a88d6b3

Browse files
committed
feat(event cache): also subscribe to a thread if we've sent a message into it
1 parent 705d6f8 commit a88d6b3

File tree

2 files changed

+257
-10
lines changed

2 files changed

+257
-10
lines changed

crates/matrix-sdk/src/event_cache/mod.rs

Lines changed: 144 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
#![forbid(missing_docs)]
2929

3030
use std::{
31-
collections::BTreeMap,
31+
collections::{BTreeMap, HashMap},
3232
fmt,
3333
sync::{Arc, OnceLock},
3434
};
@@ -44,6 +44,7 @@ use matrix_sdk_base::{
4444
},
4545
executor::AbortOnDrop,
4646
linked_chunk::{self, lazy_loader::LazyLoaderError, OwnedLinkedChunkId},
47+
store::SerializableEventContent,
4748
store_locks::LockStoreError,
4849
sync::RoomUpdates,
4950
timer,
@@ -52,7 +53,11 @@ use matrix_sdk_common::executor::{spawn, JoinHandle};
5253
#[cfg(feature = "experimental-search")]
5354
use matrix_sdk_search::error::IndexError;
5455
use room::RoomEventCacheState;
55-
use ruma::{events::AnySyncEphemeralRoomEvent, serde::Raw, OwnedEventId, OwnedRoomId, RoomId};
56+
use ruma::{
57+
events::{room::encrypted, AnySyncEphemeralRoomEvent},
58+
serde::Raw,
59+
OwnedEventId, OwnedRoomId, OwnedTransactionId, RoomId,
60+
};
5661
use tokio::{
5762
select,
5863
sync::{
@@ -62,7 +67,11 @@ use tokio::{
6267
};
6368
use tracing::{debug, error, info, info_span, instrument, trace, warn, Instrument as _, Span};
6469

65-
use crate::{client::WeakClient, Client};
70+
use crate::{
71+
client::WeakClient,
72+
send_queue::{LocalEchoContent, RoomSendQueueUpdate, SendQueueUpdate},
73+
Client,
74+
};
6675

6776
mod deduplicator;
6877
mod pagination;
@@ -425,6 +434,7 @@ impl EventCache {
425434
self.inner.generic_update_sender.subscribe()
426435
}
427436

437+
#[instrument(skip(client, thread_subscriber_sender))]
428438
async fn handle_thread_subscriber_linked_chunk_update(
429439
client: &WeakClient,
430440
thread_subscriber_sender: &Sender<()>,
@@ -508,7 +518,102 @@ impl EventCache {
508518
}
509519
}
510520

511-
return true;
521+
true
522+
}
523+
524+
#[instrument(skip(client, thread_subscriber_sender))]
525+
async fn handle_thread_subscriber_send_queue_update(
526+
client: &WeakClient,
527+
thread_subscriber_sender: &Sender<()>,
528+
events_being_sent: &mut HashMap<OwnedTransactionId, OwnedEventId>,
529+
up: SendQueueUpdate,
530+
) -> bool {
531+
let Some(client) = client.get() else {
532+
// Client shutting down.
533+
debug!("Client is shutting down, exiting thread subscriber task");
534+
return false;
535+
};
536+
537+
let room_id = up.room_id;
538+
let Some(room) = client.get_room(&room_id) else {
539+
warn!(%room_id, "unknown room");
540+
return true;
541+
};
542+
543+
let extract_thread_root = |serialized_event: SerializableEventContent| {
544+
match serialized_event.deserialize() {
545+
Ok(content) => {
546+
if let Some(encrypted::Relation::Thread(thread)) = content.relation() {
547+
return Some(thread.event_id);
548+
}
549+
}
550+
Err(err) => {
551+
warn!("error when deserializing content of a local echo: {err}");
552+
}
553+
}
554+
None
555+
};
556+
557+
let (thread_root, subscribe_up_to) = match up.update {
558+
RoomSendQueueUpdate::NewLocalEvent(local_echo) => {
559+
match local_echo.content {
560+
LocalEchoContent::Event { serialized_event, .. } => {
561+
if let Some(thread_root) = extract_thread_root(serialized_event) {
562+
events_being_sent.insert(local_echo.transaction_id, thread_root);
563+
}
564+
}
565+
LocalEchoContent::React { .. } => {
566+
// Nothing to do, reactions don't count as a thread
567+
// subscription.
568+
}
569+
}
570+
return true;
571+
}
572+
573+
RoomSendQueueUpdate::CancelledLocalEvent { transaction_id } => {
574+
events_being_sent.remove(&transaction_id);
575+
return true;
576+
}
577+
578+
RoomSendQueueUpdate::ReplacedLocalEvent { transaction_id, new_content } => {
579+
if let Some(thread_root) = extract_thread_root(new_content) {
580+
events_being_sent.insert(transaction_id, thread_root);
581+
} else {
582+
// It could be that the event isn't part of a thread anymore; handle that by
583+
// removing the pending transaction id.
584+
events_being_sent.remove(&transaction_id);
585+
}
586+
return true;
587+
}
588+
589+
RoomSendQueueUpdate::SentEvent { transaction_id, event_id } => {
590+
if let Some(thread_root) = events_being_sent.remove(&transaction_id) {
591+
(thread_root, event_id)
592+
} else {
593+
// We don't know about the event that has been sent, so ignore it.
594+
trace!(%transaction_id, "received a sent event that we didn't know about, ignoring");
595+
return true;
596+
}
597+
}
598+
599+
RoomSendQueueUpdate::SendError { .. }
600+
| RoomSendQueueUpdate::RetryEvent { .. }
601+
| RoomSendQueueUpdate::MediaUpload { .. } => {
602+
// Nothing to do for these bad boys.
603+
return true;
604+
}
605+
};
606+
607+
// And if we've found such a mention, subscribe to the thread up to this event.
608+
trace!(thread = %thread_root, up_to = %subscribe_up_to, "found a new thread to subscribe to");
609+
if let Err(err) = room.subscribe_thread_if_needed(&thread_root, Some(subscribe_up_to)).await
610+
{
611+
warn!(%err, "Failed to subscribe to thread");
612+
} else {
613+
let _ = thread_subscriber_sender.send(());
614+
}
615+
616+
true
512617
}
513618

514619
#[instrument(skip_all)]
@@ -517,16 +622,46 @@ impl EventCache {
517622
linked_chunk_update_sender: Sender<RoomEventCacheLinkedChunkUpdate>,
518623
thread_subscriber_sender: Sender<()>,
519624
) {
520-
if client.get().map_or(false, |client| !client.enabled_thread_subscriptions()) {
521-
trace!("Not spawning the thread subscriber task, because the client is shutting down or is not interested in those");
625+
let mut send_q_rx = if let Some(client) = client.get() {
626+
if !client.enabled_thread_subscriptions() {
627+
trace!("Thread subscriptions are not enabled, not spawning thread subscriber task");
628+
return;
629+
}
630+
631+
client.send_queue().subscribe()
632+
} else {
633+
trace!("Client is shutting down, not spawning thread subscriber task");
522634
return;
523-
}
635+
};
636+
637+
let mut linked_chunk_rx = linked_chunk_update_sender.subscribe();
524638

525-
let mut rx = linked_chunk_update_sender.subscribe();
639+
// A mapping of local echoes (events being sent), to their thread root, if
640+
// they're in an in-thread reply.
641+
//
642+
// Entirely managed by `handle_thread_subscriber_send_queue_update`.
643+
let mut events_being_sent = HashMap::new();
526644

527645
loop {
528646
select! {
529-
res = rx.recv() => {
647+
res = send_q_rx.recv() => {
648+
match res {
649+
Ok(up) => {
650+
if !Self::handle_thread_subscriber_send_queue_update(&client, &thread_subscriber_sender, &mut events_being_sent, up).await {
651+
break;
652+
}
653+
}
654+
Err(RecvError::Closed) => {
655+
debug!("Linked chunk update channel has been closed, exiting thread subscriber task");
656+
break;
657+
}
658+
Err(RecvError::Lagged(num_skipped)) => {
659+
warn!(num_skipped, "Lagged behind linked chunk updates");
660+
}
661+
}
662+
}
663+
664+
res = linked_chunk_rx.recv() => {
530665
match res {
531666
Ok(up) => {
532667
if !Self::handle_thread_subscriber_linked_chunk_update(&client, &thread_subscriber_sender, up).await {

crates/matrix-sdk/tests/integration/send_queue.rs

Lines changed: 113 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ use matrix_sdk::{
1414
RoomSendQueueStorageError, RoomSendQueueUpdate, SendHandle, SendQueueUpdate,
1515
},
1616
test_utils::mocks::{MatrixMock, MatrixMockServer},
17-
Client, MemoryStore,
17+
Client, MemoryStore, ThreadingSupport,
1818
};
1919
use matrix_sdk_test::{
2020
async_test, event_factory::EventFactory, InvitedRoomBuilder, KnockedRoomBuilder,
@@ -29,6 +29,7 @@ use ruma::{
2929
NewUnstablePollStartEventContent, UnstablePollAnswer, UnstablePollAnswers,
3030
UnstablePollStartContentBlock, UnstablePollStartEventContent,
3131
},
32+
relation::Thread,
3233
room::{
3334
message::{
3435
ImageMessageEventContent, MessageType, Relation, ReplyWithinThread,
@@ -3712,3 +3713,114 @@ async fn test_update_caption_while_sending_media_event() {
37123713
// That's all, folks!
37133714
assert!(watch.is_empty());
37143715
}
3716+
3717+
#[async_test]
3718+
async fn test_sending_reply_in_thread_auto_subscribe() {
3719+
let server = MatrixMockServer::new().await;
3720+
3721+
// Assuming a client that's interested in thread subscriptions,
3722+
let client = server
3723+
.client_builder()
3724+
.on_builder(|builder| {
3725+
builder.with_threading_support(ThreadingSupport::Enabled { with_subscriptions: true })
3726+
})
3727+
.build()
3728+
.await;
3729+
3730+
client.event_cache().subscribe().unwrap();
3731+
3732+
let room_id = room_id!("!a:b.c");
3733+
let room = server.sync_joined_room(&client, room_id).await;
3734+
3735+
server.mock_room_state_encryption().plain().mount().await;
3736+
3737+
// When I send a message to a thread,
3738+
let thread_root = event_id!("$thread");
3739+
3740+
let mut content = RoomMessageEventContent::text_plain("hello world");
3741+
content.relates_to =
3742+
Some(Relation::Thread(Thread::plain(thread_root.to_owned(), thread_root.to_owned())));
3743+
3744+
server.mock_room_send().ok(event_id!("$reply")).mock_once().mount().await;
3745+
3746+
server
3747+
.mock_put_thread_subscription()
3748+
.match_room_id(room_id.to_owned())
3749+
.match_thread_id(thread_root.to_owned())
3750+
.ok()
3751+
.mock_once()
3752+
.mount()
3753+
.await;
3754+
3755+
let (_, mut stream) = room.send_queue().subscribe().await.unwrap();
3756+
room.send_queue().send(content.into()).await.unwrap();
3757+
3758+
// Let the send queue process the event.
3759+
assert_let_timeout!(Ok(RoomSendQueueUpdate::NewLocalEvent(..)) = stream.recv());
3760+
assert_let_timeout!(Ok(RoomSendQueueUpdate::SentEvent { .. }) = stream.recv());
3761+
assert_let_timeout!(Ok(()) = thread_subscriber_updates.recv());
3762+
3763+
// Check the endpoints have been correctly called.
3764+
server.server().reset().await;
3765+
3766+
// Now, if I send a message in a thread I've already subscribed to, in automatic
3767+
// mode, this promotes the subscription to manual.
3768+
3769+
// Subscribed, automatically.
3770+
server
3771+
.mock_get_thread_subscription()
3772+
.match_room_id(room_id.to_owned())
3773+
.match_thread_id(thread_root.to_owned())
3774+
.ok(true)
3775+
.mount()
3776+
.await;
3777+
3778+
// I'll get one subscription.
3779+
server
3780+
.mock_put_thread_subscription()
3781+
.match_room_id(room_id.to_owned())
3782+
.match_thread_id(thread_root.to_owned())
3783+
.ok()
3784+
.mock_once()
3785+
.mount()
3786+
.await;
3787+
3788+
server.mock_room_send().ok(event_id!("$reply")).mock_once().mount().await;
3789+
3790+
let mut content = RoomMessageEventContent::text_plain("hello world");
3791+
content.relates_to =
3792+
Some(Relation::Thread(Thread::plain(thread_root.to_owned(), thread_root.to_owned())));
3793+
room.send_queue().send(content.into()).await.unwrap();
3794+
3795+
// Let the send queue process the event.
3796+
assert_let!(RoomSendQueueUpdate::NewLocalEvent(..) = stream.recv().await.unwrap());
3797+
assert_let!(RoomSendQueueUpdate::SentEvent { .. } = stream.recv().await.unwrap());
3798+
3799+
// Check the endpoints have been correctly called.
3800+
server.server().reset().await;
3801+
3802+
// Subscribed, but manually.
3803+
server
3804+
.mock_get_thread_subscription()
3805+
.match_room_id(room_id.to_owned())
3806+
.match_thread_id(thread_root.to_owned())
3807+
.ok(false)
3808+
.mount()
3809+
.await;
3810+
3811+
// I'll get zero subscription.
3812+
server.mock_put_thread_subscription().ok().expect(0).mount().await;
3813+
3814+
server.mock_room_send().ok(event_id!("$reply")).mock_once().mount().await;
3815+
3816+
let mut content = RoomMessageEventContent::text_plain("hello world");
3817+
content.relates_to =
3818+
Some(Relation::Thread(Thread::plain(thread_root.to_owned(), thread_root.to_owned())));
3819+
room.send_queue().send(content.into()).await.unwrap();
3820+
3821+
// Let the send queue process the event.
3822+
assert_let!(RoomSendQueueUpdate::NewLocalEvent(..) = stream.recv().await.unwrap());
3823+
assert_let!(RoomSendQueueUpdate::SentEvent { .. } = stream.recv().await.unwrap());
3824+
3825+
sleep(Duration::from_millis(100)).await;
3826+
}

0 commit comments

Comments
 (0)