1
- use std:: time:: Duration ;
1
+ use std:: { sync :: Arc , time:: Duration } ;
2
2
3
3
use crate :: {
4
4
error:: { ReceiveError , SendError , TypedReceiveError } ,
5
5
message:: { IncomingMessage , OutgoingMessage , PayloadTypeName , TypedIncomingMessage } ,
6
- traits:: { CommunicationBackend , CryptoProvider , SessionRepository } ,
6
+ traits:: {
7
+ CommunicationBackend , CommunicationBackendReceiver , CryptoProvider , SessionRepository ,
8
+ } ,
7
9
} ;
8
10
9
11
pub struct IpcClient < Crypto , Com , Ses >
@@ -17,45 +19,111 @@ where
17
19
sessions : Ses ,
18
20
}
19
21
22
+ /// A subscription to receive messages over IPC.
23
+ /// The subcription will start buffering messages after its creation and return them
24
+ /// when receive() is called. Messages received before the subscription was created will not be
25
+ /// returned.
26
+ pub struct IpcClientSubscription < Crypto , Com , Ses >
27
+ where
28
+ Crypto : CryptoProvider < Com , Ses > ,
29
+ Com : CommunicationBackend ,
30
+ Ses : SessionRepository < Session = Crypto :: Session > ,
31
+ {
32
+ receiver : Com :: Receiver ,
33
+ client : Arc < IpcClient < Crypto , Com , Ses > > ,
34
+ topic : Option < String > ,
35
+ }
36
+
37
+ /// A subscription to receive messages over IPC.
38
+ /// The subcription will start buffering messages after its creation and return them
39
+ /// when receive() is called. Messages received before the subscription was created will not be
40
+ /// returned.
41
+ pub struct IpcClientTypedSubscription < Crypto , Com , Ses , Payload >
42
+ where
43
+ Crypto : CryptoProvider < Com , Ses > ,
44
+ Com : CommunicationBackend ,
45
+ Ses : SessionRepository < Session = Crypto :: Session > ,
46
+ Payload : TryFrom < Vec < u8 > > + PayloadTypeName ,
47
+ {
48
+ receiver : Com :: Receiver ,
49
+ client : Arc < IpcClient < Crypto , Com , Ses > > ,
50
+ _payload : std:: marker:: PhantomData < Payload > ,
51
+ }
52
+
20
53
impl < Crypto , Com , Ses > IpcClient < Crypto , Com , Ses >
21
54
where
22
55
Crypto : CryptoProvider < Com , Ses > ,
23
56
Com : CommunicationBackend ,
24
57
Ses : SessionRepository < Session = Crypto :: Session > ,
25
58
{
26
- pub fn new ( crypto : Crypto , communication : Com , sessions : Ses ) -> Self {
27
- Self {
59
+ pub fn new ( crypto : Crypto , communication : Com , sessions : Ses ) -> Arc < Self > {
60
+ Arc :: new ( Self {
28
61
crypto,
29
62
communication,
30
63
sessions,
31
- }
64
+ } )
32
65
}
33
66
34
67
/// Send a message
35
68
pub async fn send (
36
- & self ,
69
+ self : & Arc < Self > ,
37
70
message : OutgoingMessage ,
38
71
) -> Result < ( ) , SendError < Crypto :: SendError , Com :: SendError > > {
39
72
self . crypto
40
73
. send ( & self . communication , & self . sessions , message)
41
74
. await
42
75
}
43
76
77
+ /// Create a subscription to receive messages, optionally filtered by topic.
78
+ /// Setting the topic to `None` will receive all messages.
79
+ pub async fn subscribe (
80
+ self : & Arc < Self > ,
81
+ topic : Option < String > ,
82
+ ) -> IpcClientSubscription < Crypto , Com , Ses > {
83
+ IpcClientSubscription {
84
+ receiver : self . communication . subscribe ( ) . await ,
85
+ client : self . clone ( ) ,
86
+ topic,
87
+ }
88
+ }
89
+
90
+ /// Create a subscription to receive messages that can be deserialized into the provided payload
91
+ /// type.
92
+ pub async fn subscribe_typed < Payload > (
93
+ self : & Arc < Self > ,
94
+ ) -> IpcClientTypedSubscription < Crypto , Com , Ses , Payload >
95
+ where
96
+ Payload : TryFrom < Vec < u8 > > + PayloadTypeName ,
97
+ {
98
+ IpcClientTypedSubscription {
99
+ receiver : self . communication . subscribe ( ) . await ,
100
+ client : self . clone ( ) ,
101
+ _payload : std:: marker:: PhantomData ,
102
+ }
103
+ }
104
+
44
105
/// Receive a message, optionally filtering by topic.
45
106
/// Setting the topic to `None` will receive all messages.
46
107
/// Setting the timeout to `None` will wait indefinitely.
47
- pub async fn receive (
108
+ async fn receive (
48
109
& self ,
49
- topic : Option < String > ,
110
+ receiver : & Com :: Receiver ,
111
+ topic : & Option < String > ,
50
112
timeout : Option < Duration > ,
51
- ) -> Result < IncomingMessage , ReceiveError < Crypto :: ReceiveError , Com :: ReceiveError > > {
113
+ ) -> Result <
114
+ IncomingMessage ,
115
+ ReceiveError <
116
+ Crypto :: ReceiveError ,
117
+ <Com :: Receiver as CommunicationBackendReceiver >:: ReceiveError ,
118
+ > ,
119
+ > {
52
120
let receive_loop = async {
53
121
loop {
54
122
let received = self
55
123
. crypto
56
- . receive ( & self . communication , & self . sessions )
124
+ . receive ( receiver , & self . communication , & self . sessions )
57
125
. await ?;
58
- if topic. is_none ( ) || received. topic == topic {
126
+ if topic. is_none ( ) || & received. topic == topic {
59
127
return Ok ( received) ;
60
128
}
61
129
}
@@ -72,26 +140,75 @@ where
72
140
73
141
/// Receive a message, skipping any messages that cannot be deserialized into the expected
74
142
/// payload type.
75
- pub async fn receive_typed < Payload > (
143
+ async fn receive_typed < Payload > (
76
144
& self ,
145
+ receiver : & Com :: Receiver ,
77
146
timeout : Option < Duration > ,
78
147
) -> Result <
79
148
TypedIncomingMessage < Payload > ,
80
149
TypedReceiveError <
81
150
<Payload as TryFrom < Vec < u8 > > >:: Error ,
82
151
Crypto :: ReceiveError ,
83
- Com :: ReceiveError ,
152
+ < Com :: Receiver as CommunicationBackendReceiver > :: ReceiveError ,
84
153
> ,
85
154
>
86
155
where
87
156
Payload : TryFrom < Vec < u8 > > + PayloadTypeName ,
88
157
{
89
158
let topic = Some ( Payload :: name ( ) ) ;
90
- let received = self . receive ( topic, timeout) . await ?;
159
+ let received = self . receive ( receiver , & topic, timeout) . await ?;
91
160
received. try_into ( ) . map_err ( TypedReceiveError :: Typing )
92
161
}
93
162
}
94
163
164
+ impl < Crypto , Com , Ses > IpcClientSubscription < Crypto , Com , Ses >
165
+ where
166
+ Crypto : CryptoProvider < Com , Ses > ,
167
+ Com : CommunicationBackend ,
168
+ Ses : SessionRepository < Session = Crypto :: Session > ,
169
+ {
170
+ /// Receive a message, optionally filtering by topic.
171
+ /// Setting the timeout to `None` will wait indefinitely.
172
+ pub async fn receive (
173
+ & self ,
174
+ timeout : Option < Duration > ,
175
+ ) -> Result <
176
+ IncomingMessage ,
177
+ ReceiveError <
178
+ Crypto :: ReceiveError ,
179
+ <Com :: Receiver as CommunicationBackendReceiver >:: ReceiveError ,
180
+ > ,
181
+ > {
182
+ self . client
183
+ . receive ( & self . receiver , & self . topic , timeout)
184
+ . await
185
+ }
186
+ }
187
+
188
+ impl < Crypto , Com , Ses , Payload > IpcClientTypedSubscription < Crypto , Com , Ses , Payload >
189
+ where
190
+ Crypto : CryptoProvider < Com , Ses > ,
191
+ Com : CommunicationBackend ,
192
+ Ses : SessionRepository < Session = Crypto :: Session > ,
193
+ Payload : TryFrom < Vec < u8 > > + PayloadTypeName ,
194
+ {
195
+ /// Receive a message.
196
+ /// Setting the timeout to `None` will wait indefinitely.
197
+ pub async fn receive (
198
+ & self ,
199
+ timeout : Option < Duration > ,
200
+ ) -> Result <
201
+ TypedIncomingMessage < Payload > ,
202
+ TypedReceiveError <
203
+ <Payload as TryFrom < Vec < u8 > > >:: Error ,
204
+ Crypto :: ReceiveError ,
205
+ <Com :: Receiver as CommunicationBackendReceiver >:: ReceiveError ,
206
+ > ,
207
+ > {
208
+ self . client . receive_typed ( & self . receiver , timeout) . await
209
+ }
210
+ }
211
+
95
212
#[ cfg( test) ]
96
213
mod tests {
97
214
use std:: collections:: HashMap ;
@@ -121,6 +238,7 @@ mod tests {
121
238
122
239
async fn receive (
123
240
& self ,
241
+ _receiver : & <TestCommunicationBackend as CommunicationBackend >:: Receiver ,
124
242
_communication : & TestCommunicationBackend ,
125
243
_sessions : & TestSessionRepository ,
126
244
) -> Result < IncomingMessage , ReceiveError < String , TestCommunicationBackendReceiveError > >
@@ -176,7 +294,8 @@ mod tests {
176
294
let session_map = TestSessionRepository :: new ( HashMap :: new ( ) ) ;
177
295
let client = IpcClient :: new ( crypto_provider, communication_provider, session_map) ;
178
296
179
- let error = client. receive ( None , None ) . await . unwrap_err ( ) ;
297
+ let subscription = client. subscribe ( None ) . await ;
298
+ let error = subscription. receive ( None ) . await . unwrap_err ( ) ;
180
299
181
300
assert_eq ! ( error, ReceiveError :: Crypto ( "Crypto error" . to_string( ) ) ) ;
182
301
}
@@ -212,8 +331,9 @@ mod tests {
212
331
let session_map = InMemorySessionRepository :: new ( HashMap :: new ( ) ) ;
213
332
let client = IpcClient :: new ( crypto_provider, communication_provider. clone ( ) , session_map) ;
214
333
215
- communication_provider. push_incoming ( message. clone ( ) ) . await ;
216
- let received_message = client. receive ( None , None ) . await . unwrap ( ) ;
334
+ let subscription = & client. subscribe ( None ) . await ;
335
+ communication_provider. push_incoming ( message. clone ( ) ) ;
336
+ let received_message = subscription. receive ( None ) . await . unwrap ( ) ;
217
337
218
338
assert_eq ! ( received_message, message) ;
219
339
}
@@ -237,20 +357,12 @@ mod tests {
237
357
let communication_provider = TestCommunicationBackend :: new ( ) ;
238
358
let session_map = InMemorySessionRepository :: new ( HashMap :: new ( ) ) ;
239
359
let client = IpcClient :: new ( crypto_provider, communication_provider. clone ( ) , session_map) ;
240
- communication_provider
241
- . push_incoming ( non_matching_message. clone ( ) )
242
- . await ;
243
- communication_provider
244
- . push_incoming ( non_matching_message. clone ( ) )
245
- . await ;
246
- communication_provider
247
- . push_incoming ( matching_message. clone ( ) )
248
- . await ;
249
-
250
- let received_message: IncomingMessage = client
251
- . receive ( Some ( "matching_topic" . to_owned ( ) ) , None )
252
- . await
253
- . unwrap ( ) ;
360
+ let subscription = client. subscribe ( Some ( "matching_topic" . to_owned ( ) ) ) . await ;
361
+ communication_provider. push_incoming ( non_matching_message. clone ( ) ) ;
362
+ communication_provider. push_incoming ( non_matching_message. clone ( ) ) ;
363
+ communication_provider. push_incoming ( matching_message. clone ( ) ) ;
364
+
365
+ let received_message: IncomingMessage = subscription. receive ( None ) . await . unwrap ( ) ;
254
366
255
367
assert_eq ! ( received_message, matching_message) ;
256
368
}
@@ -302,18 +414,12 @@ mod tests {
302
414
let communication_provider = TestCommunicationBackend :: new ( ) ;
303
415
let session_map = InMemorySessionRepository :: new ( HashMap :: new ( ) ) ;
304
416
let client = IpcClient :: new ( crypto_provider, communication_provider. clone ( ) , session_map) ;
305
- communication_provider
306
- . push_incoming ( unrelated. clone ( ) )
307
- . await ;
308
- communication_provider
309
- . push_incoming ( unrelated. clone ( ) )
310
- . await ;
311
- communication_provider
312
- . push_incoming ( typed_message. clone ( ) . try_into ( ) . unwrap ( ) )
313
- . await ;
314
-
315
- let received_message: TypedIncomingMessage < TestPayload > =
316
- client. receive_typed ( None ) . await . unwrap ( ) ;
417
+ let subscription = client. subscribe_typed :: < TestPayload > ( ) . await ;
418
+ communication_provider. push_incoming ( unrelated. clone ( ) ) ;
419
+ communication_provider. push_incoming ( unrelated. clone ( ) ) ;
420
+ communication_provider. push_incoming ( typed_message. clone ( ) . try_into ( ) . unwrap ( ) ) ;
421
+
422
+ let received_message = subscription. receive ( None ) . await . unwrap ( ) ;
317
423
318
424
assert_eq ! ( received_message, typed_message) ;
319
425
}
@@ -358,11 +464,10 @@ mod tests {
358
464
let communication_provider = TestCommunicationBackend :: new ( ) ;
359
465
let session_map = InMemorySessionRepository :: new ( HashMap :: new ( ) ) ;
360
466
let client = IpcClient :: new ( crypto_provider, communication_provider. clone ( ) , session_map) ;
361
- communication_provider
362
- . push_incoming ( non_deserializable_message. clone ( ) )
363
- . await ;
467
+ let subscription = client. subscribe_typed :: < TestPayload > ( ) . await ;
468
+ communication_provider. push_incoming ( non_deserializable_message. clone ( ) ) ;
364
469
365
- let result: Result < TypedIncomingMessage < TestPayload > , _ > = client . receive_typed ( None ) . await ;
470
+ let result = subscription . receive ( None ) . await ;
366
471
367
472
assert ! ( matches!(
368
473
result,
0 commit comments