Skip to content

Commit 704dbc9

Browse files
committed
authmailbox: add multi subscriber helper type
This MultiSubscription helper struct allows us to subscribe to receive messages for multiple keys held by a receiving wallet but all consolidated into a single message channel.
1 parent 6e03ce4 commit 704dbc9

File tree

4 files changed

+243
-2
lines changed

4 files changed

+243
-2
lines changed

authmailbox/client_test.go

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"bytes"
55
"context"
66
"fmt"
7+
"net/url"
78
"os"
89
"testing"
910
"time"
@@ -236,6 +237,40 @@ func TestServerClientAuthAndRestart(t *testing.T) {
236237
client2.stop(t)
237238
})
238239

240+
// We also add a multi-subscription to the same two keys, so we can make
241+
// sure we can receive messages from multiple clients at once.
242+
multiSub := NewMultiSubscription(*clientCfg)
243+
err := multiSub.Subscribe(
244+
ctx, url.URL{Host: clientCfg.ServerAddress}, clientKey1, filter,
245+
)
246+
require.NoError(t, err)
247+
err = multiSub.Subscribe(
248+
ctx, url.URL{Host: clientCfg.ServerAddress}, clientKey2, filter,
249+
)
250+
require.NoError(t, err)
251+
t.Cleanup(func() {
252+
require.NoError(t, multiSub.Stop())
253+
})
254+
msgChan := multiSub.MessageChan()
255+
readMultiSub := func(targetID ...uint64) {
256+
t.Helper()
257+
select {
258+
case inboundMsgs := <-msgChan:
259+
receivedIDs := fn.Map(
260+
inboundMsgs.Messages,
261+
func(msg *mboxrpc.MailboxMessage) uint64 {
262+
return msg.MessageId
263+
},
264+
)
265+
for _, target := range targetID {
266+
require.Contains(t, receivedIDs, target)
267+
}
268+
case <-time.After(testTimeout):
269+
t.Fatalf("timeout waiting for message with ID %v",
270+
targetID)
271+
}
272+
}
273+
239274
// Send a message to all clients.
240275
msg1 := &Message{
241276
ID: 1000,
@@ -244,14 +279,15 @@ func TestServerClientAuthAndRestart(t *testing.T) {
244279
}
245280

246281
// We also store the message in the store, so we can retrieve it later.
247-
_, err := harness.mockMsgStore.StoreMessage(ctx, randOp, msg1)
282+
_, err = harness.mockMsgStore.StoreMessage(ctx, randOp, msg1)
248283
require.NoError(t, err)
249284

250285
harness.srv.publishMessage(msg1)
251286

252287
// We should be able to receive that message.
253288
client1.readMessages(t, msg1.ID)
254289
client2.readMessages(t, msg1.ID)
290+
readMultiSub(msg1.ID)
255291

256292
// We now stop the server and assert that the subscription is no longer
257293
// active.
@@ -282,6 +318,7 @@ func TestServerClientAuthAndRestart(t *testing.T) {
282318
// We should be able to receive that message.
283319
client1.readMessages(t, msg2.ID)
284320
client2.readMessages(t, msg2.ID)
321+
readMultiSub(msg2.ID)
285322

286323
// If we now start a third client, we should be able to receive all
287324
// three messages, given we are using the same key and specify the
@@ -314,6 +351,23 @@ func TestServerClientAuthAndRestart(t *testing.T) {
314351
harness.srv.publishMessage(msg3)
315352
client4.expectNoMessage(t)
316353
client1.readMessages(t, msg3.ID)
354+
client2.readMessages(t, msg3.ID)
355+
client3.readMessages(t, msg3.ID)
356+
readMultiSub(msg3.ID)
357+
358+
// Let's make sure that a message sent to the second key is only
359+
// received by the fourth client and the multi-subscription.
360+
msg4 := &Message{
361+
ID: 1001,
362+
ReceiverKey: *clientKey2.PubKey,
363+
ArrivalTimestamp: time.Now(),
364+
}
365+
harness.srv.publishMessage(msg4)
366+
client1.expectNoMessage(t)
367+
client2.expectNoMessage(t)
368+
client3.expectNoMessage(t)
369+
client4.readMessages(t, msg4.ID)
370+
readMultiSub(msg4.ID)
317371
}
318372

319373
// TestSendMessage tests the SendMessage RPC of the server and its ability to

authmailbox/mock.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,5 +87,8 @@ func (s *MockMsgStore) QueryMessages(_ context.Context,
8787
}
8888

8989
func (s *MockMsgStore) NumMessages(context.Context) uint64 {
90+
s.mu.Lock()
91+
defer s.mu.Unlock()
92+
9093
return uint64(len(s.messages))
9194
}

authmailbox/multi_subscription.go

Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
package authmailbox
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"net/url"
7+
"sync"
8+
9+
"github.com/lightninglabs/taproot-assets/asset"
10+
lfn "github.com/lightningnetwork/lnd/fn/v2"
11+
"github.com/lightningnetwork/lnd/keychain"
12+
)
13+
14+
// clientSubscriptions holds the subscriptions and cancel functions for a
15+
// specific mailbox client.
16+
type clientSubscriptions struct {
17+
// client is the mailbox client that this subscription belongs to.
18+
client *Client
19+
20+
// subscriptions holds the active subscriptions for this client, keyed
21+
// by the serialized public key of the receiver.
22+
subscriptions map[asset.SerializedKey]ReceiveSubscription
23+
24+
// cancels holds the cancel functions for each subscription, also keyed
25+
// by the serialized public key of the receiver.
26+
cancels map[asset.SerializedKey]context.CancelFunc
27+
}
28+
29+
// MultiSubscription is a subscription manager that can handle multiple mailbox
30+
// clients, allowing subscriptions to different accounts across different
31+
// mailbox servers. It manages subscriptions and message queues for each client
32+
// and provides a unified interface for receiving messages.
33+
type MultiSubscription struct {
34+
// baseClientConfig holds the basic configuration for the mailbox
35+
// clients. All fields except the ServerAddress are used to create
36+
// new mailbox clients when needed.
37+
baseClientConfig ClientConfig
38+
39+
// clients holds the active mailbox clients, keyed by their server URL.
40+
clients map[url.URL]*clientSubscriptions
41+
42+
// msgQueue is the concurrent queue that holds received messages from
43+
// all subscriptions across all clients. This allows for a unified
44+
// message channel that can be used to receive messages from any
45+
// subscribed account, regardless of which mailbox server it belongs to.
46+
msgQueue *lfn.ConcurrentQueue[*ReceivedMessages]
47+
48+
sync.RWMutex
49+
}
50+
51+
// NewMultiSubscription creates a new MultiSubscription instance.
52+
func NewMultiSubscription(baseClientConfig ClientConfig) *MultiSubscription {
53+
queue := lfn.NewConcurrentQueue[*ReceivedMessages](lfn.DefaultQueueSize)
54+
queue.Start()
55+
56+
return &MultiSubscription{
57+
baseClientConfig: baseClientConfig,
58+
clients: make(map[url.URL]*clientSubscriptions),
59+
msgQueue: queue,
60+
}
61+
}
62+
63+
// Subscribe adds a new subscription for the specified client URL and receiver
64+
// key. It starts a new mailbox client if one does not already exist for the
65+
// given URL. The subscription will receive messages that match the provided
66+
// filter and will send them to the shared message queue.
67+
func (m *MultiSubscription) Subscribe(ctx context.Context, serverURL url.URL,
68+
receiverKey keychain.KeyDescriptor, filter MessageFilter) error {
69+
70+
// We hold the mutex for access to common resources.
71+
m.Lock()
72+
cfgCopy := m.baseClientConfig
73+
client, ok := m.clients[serverURL]
74+
75+
// If this is the first time we're seeing a server URL, we first create
76+
// a network connection to the mailbox server.
77+
if !ok {
78+
cfgCopy.ServerAddress = serverURL.Host
79+
80+
mboxClient := NewClient(&cfgCopy)
81+
client = &clientSubscriptions{
82+
client: mboxClient,
83+
subscriptions: make(
84+
map[asset.SerializedKey]ReceiveSubscription,
85+
),
86+
cancels: make(
87+
map[asset.SerializedKey]context.CancelFunc,
88+
),
89+
}
90+
m.clients[serverURL] = client
91+
92+
err := mboxClient.Start()
93+
if err != nil {
94+
m.Unlock()
95+
return fmt.Errorf("unable to create mailbox client: %w",
96+
err)
97+
}
98+
}
99+
100+
// We release the lock here again, because StartAccountSubscription
101+
// might block for a while, and we don't want to hold the lock
102+
// unnecessarily long.
103+
m.Unlock()
104+
105+
ctx, cancel := context.WithCancel(ctx)
106+
subscription, err := client.client.StartAccountSubscription(
107+
ctx, m.msgQueue.ChanIn(), receiverKey, filter,
108+
)
109+
if err != nil {
110+
cancel()
111+
return fmt.Errorf("unable to start mailbox subscription: %w",
112+
err)
113+
}
114+
115+
// We hold the lock again to safely add the subscription and cancel
116+
// function to the client's maps.
117+
m.Lock()
118+
key := asset.ToSerialized(receiverKey.PubKey)
119+
client.subscriptions[key] = subscription
120+
client.cancels[key] = cancel
121+
m.Unlock()
122+
123+
return nil
124+
}
125+
126+
// MessageChan returns a channel that can be used to receive messages from all
127+
// subscriptions across all mailbox clients. This channel will receive
128+
// ReceivedMessages, which contain the messages and their associated
129+
// metadata, such as the sender and receiver keys.
130+
func (m *MultiSubscription) MessageChan() <-chan *ReceivedMessages {
131+
return m.msgQueue.ChanOut()
132+
}
133+
134+
// Stop stops all active subscriptions and mailbox clients. It cancels all
135+
// active subscription contexts and waits for all clients to stop gracefully.
136+
func (m *MultiSubscription) Stop() error {
137+
defer m.msgQueue.Stop()
138+
139+
log.Info("Stopping all mailbox clients and subscriptions...")
140+
141+
m.RLock()
142+
defer m.RUnlock()
143+
144+
var lastErr error
145+
for _, client := range m.clients {
146+
for _, cancel := range client.cancels {
147+
cancel()
148+
}
149+
150+
for _, sub := range client.subscriptions {
151+
err := sub.Stop()
152+
if err != nil {
153+
log.Errorf("Error stopping subscription: %v",
154+
err)
155+
lastErr = err
156+
}
157+
}
158+
159+
if err := client.client.Stop(); err != nil {
160+
log.Errorf("Error stopping client: %v", err)
161+
lastErr = err
162+
}
163+
}
164+
165+
return lastErr
166+
}

authmailbox/receive_subscription.go

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,18 @@ func (s *receiveSubscription) connectServerStream(ctx context.Context,
200200
err error
201201
)
202202
for i := 0; i < numRetries; i++ {
203+
// If we're shutting down, we don't want to re-try connecting.
204+
select {
205+
case <-s.quit:
206+
log.DebugS(ctx, "Client is shutting down...")
207+
return ErrClientShutdown
208+
209+
case <-ctx.Done():
210+
log.DebugS(ctx, "Client is shutting down...")
211+
return ErrClientShutdown
212+
default:
213+
}
214+
203215
// Wait before connecting in case this is a re-connect trial.
204216
if backoff != 0 {
205217
err = s.wait(backoff)
@@ -441,16 +453,22 @@ func (s *receiveSubscription) HandleServerShutdown(ctx context.Context,
441453

442454
// closeStream closes the long-lived stream connection to the server.
443455
func (s *receiveSubscription) closeStream(ctx context.Context) error {
456+
log.InfoS(ctx, "Closing stream")
457+
444458
s.streamMutex.Lock()
445459
defer s.streamMutex.Unlock()
446460

461+
if s.streamCancel != nil {
462+
s.streamCancel()
463+
}
464+
447465
if s.serverStream == nil {
466+
log.InfoS(ctx, "Server stream is not connected")
448467
return nil
449468
}
450469

451470
log.DebugS(ctx, "Closing server stream")
452471
err := s.serverStream.CloseSend()
453-
s.streamCancel()
454472
s.serverStream = nil
455473

456474
return err

0 commit comments

Comments
 (0)