-
Notifications
You must be signed in to change notification settings - Fork 14
Expand file tree
/
Copy pathsocket.rs
More file actions
299 lines (246 loc) · 9.89 KB
/
socket.rs
File metadata and controls
299 lines (246 loc) · 9.89 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
use std::{
collections::HashSet,
net::SocketAddr,
path::PathBuf,
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
use futures::Stream;
use rustc_hash::FxHashMap;
use tokio::{
net::{ToSocketAddrs, lookup_host},
sync::mpsc,
};
use msg_common::{IpAddrExt, JoinMap};
use msg_transport::{Address, Transport};
// Subscriber-specific statistics type returned by `SubSocket::stats`.
use super::stats::SubStats;
// Socket internals (commands, driver, state, options) from the parent module (sub/mod.rs).
use super::{
// REMOVED: Old/removed stats structs
// Command, PubMessage, SocketState, SocketStats, SocketWideStats, SubDriver, SubError,
Command,
DEFAULT_BUFFER_SIZE,
PubMessage,
SocketState,
SubDriver,
SubError,
SubOptions,
};
/// A subscriber socket. This socket implements [`Stream`] and yields incoming [`PubMessage`]s.
///
/// The socket is a thin frontend. It forwards commands (connect, disconnect, subscribe,
/// unsubscribe) to a driver over an mpsc channel and receives published messages back over a
/// second channel. The driver is built together with the socket but only spawned on the first
/// command method call.
// NOTE: field declaration order is also the drop order of these fields.
pub struct SubSocket<T: Transport<A>, A: Address> {
    /// Command channel to the socket driver.
    to_driver: mpsc::Sender<Command<A>>,
    /// Receiver for messages forwarded by the driver; drained by the `Stream` impl.
    from_driver: mpsc::Receiver<PubMessage<A>>,
    /// Options for the socket. The same `Arc` is shared with the driver.
    // Not read by the frontend in the visible code; the allow keeps the field without a
    // dead-code warning. NOTE(review): confirm whether it is still needed.
    #[allow(unused)]
    options: Arc<SubOptions>,
    /// The not-yet-spawned driver. `Some` until the first command spawns it, `None` afterwards.
    driver: Option<SubDriver<T, A>>,
    /// Socket state, shared with the driver task. Contains the unified stats.
    state: Arc<SocketState<A>>,
    /// Marker for the transport type.
    _marker: std::marker::PhantomData<T>,
}
impl<T> SubSocket<T, SocketAddr>
where
    T: Transport<SocketAddr> + Send + Sync + Unpin + 'static,
{
    /// Resolves `endpoint` and connects to the first address it yields.
    ///
    /// An unspecified IP (`0.0.0.0` / `::`) is rewritten to localhost before it is handed to
    /// the driver, because some transports (e.g. Quinn) cannot dial an unspecified address.
    ///
    /// # Errors
    ///
    /// Returns an error if name resolution fails, [`SubError::NoValidEndpoints`] if it yields
    /// no address, or [`SubError::SocketClosed`] if the driver is gone.
    pub async fn connect(&mut self, endpoint: impl ToSocketAddrs) -> Result<(), SubError> {
        let endpoint = Self::resolve_endpoint(endpoint).await?;
        self.connect_inner(endpoint).await
    }

    /// Parses `endpoint` as a literal socket address (no DNS lookup) and connects to it without
    /// waiting for command-buffer space.
    ///
    /// The same unspecified-IP → localhost rewrite as in `connect` is applied.
    ///
    /// # Errors
    ///
    /// Returns [`SubError::NoValidEndpoints`] if `endpoint` is not a valid socket address,
    /// [`SubError::ChannelFull`] if the command buffer is full, or [`SubError::SocketClosed`]
    /// if the driver is gone.
    pub fn try_connect(&mut self, endpoint: impl Into<String>) -> Result<(), SubError> {
        let endpoint = Self::parse_endpoint(endpoint)?;
        self.try_connect_inner(endpoint)
    }

    /// Resolves `endpoint` and disconnects from the first address it yields.
    ///
    /// The address is normalized exactly as in `connect`. Disconnecting with the same input
    /// that was used to connect (e.g. `0.0.0.0:4444`) therefore sends the driver the same
    /// address it received on connect.
    ///
    /// # Errors
    ///
    /// Returns an error if name resolution fails, [`SubError::NoValidEndpoints`] if it yields
    /// no address, or [`SubError::SocketClosed`] if the driver is gone.
    pub async fn disconnect(&mut self, endpoint: impl ToSocketAddrs) -> Result<(), SubError> {
        let endpoint = Self::resolve_endpoint(endpoint).await?;
        self.disconnect_inner(endpoint).await
    }

    /// Parses `endpoint` as a literal socket address and disconnects from it without waiting
    /// for command-buffer space. The address is normalized exactly as in `try_connect`.
    ///
    /// # Errors
    ///
    /// Returns [`SubError::NoValidEndpoints`] if `endpoint` is not a valid socket address,
    /// [`SubError::ChannelFull`] if the command buffer is full, or [`SubError::SocketClosed`]
    /// if the driver is gone.
    pub fn try_disconnect(&mut self, endpoint: impl Into<String>) -> Result<(), SubError> {
        let endpoint = Self::parse_endpoint(endpoint)?;
        self.try_disconnect_inner(endpoint)
    }

    /// Looks up `endpoint` and returns its first resolved address, normalized for dialing.
    async fn resolve_endpoint(endpoint: impl ToSocketAddrs) -> Result<SocketAddr, SubError> {
        let mut addrs = lookup_host(endpoint).await?;
        let first = addrs.next().ok_or(SubError::NoValidEndpoints)?;
        Ok(Self::normalize_endpoint(first))
    }

    /// Parses a literal `ip:port` string and normalizes it for dialing.
    fn parse_endpoint(endpoint: impl Into<String>) -> Result<SocketAddr, SubError> {
        let raw: String = endpoint.into();
        let addr: SocketAddr = raw.parse().map_err(|_| SubError::NoValidEndpoints)?;
        Ok(Self::normalize_endpoint(addr))
    }

    /// Replaces an unspecified IP address with localhost, keeping the port.
    fn normalize_endpoint(mut endpoint: SocketAddr) -> SocketAddr {
        // Some transport implementations (e.g. Quinn) can't dial an unspecified
        // IP address, so replace it with localhost.
        if endpoint.ip().is_unspecified() {
            endpoint.set_ip(endpoint.ip().as_localhost());
        }
        endpoint
    }
}
impl<T> SubSocket<T, PathBuf>
where
    T: Transport<PathBuf> + Send + Sync + Unpin + 'static,
{
    /// Connects to the publisher listening at `path`, waiting for command-buffer space if
    /// necessary.
    pub async fn connect_path(&mut self, path: impl Into<PathBuf>) -> Result<(), SubError> {
        let target: PathBuf = path.into();
        self.connect_inner(target).await
    }

    /// Connects to the publisher listening at `path` without waiting; fails if the command
    /// buffer is currently full.
    pub fn try_connect_path(&mut self, path: impl Into<PathBuf>) -> Result<(), SubError> {
        let target: PathBuf = path.into();
        self.try_connect_inner(target)
    }

    /// Disconnects from the publisher at `path`, waiting for command-buffer space if
    /// necessary.
    pub async fn disconnect_path(&mut self, path: impl Into<PathBuf>) -> Result<(), SubError> {
        let target: PathBuf = path.into();
        self.disconnect_inner(target).await
    }

    /// Disconnects from the publisher at `path` without waiting; fails if the command buffer
    /// is currently full.
    pub fn try_disconnect_path(&mut self, path: impl Into<PathBuf>) -> Result<(), SubError> {
        let target: PathBuf = path.into();
        self.try_disconnect_inner(target)
    }
}
impl<T, A> SubSocket<T, A>
where
T: Transport<A> + Send + Sync + Unpin + 'static,
A: Address,
{
/// Creates a new subscriber socket with the default [`SubOptions`].
pub fn new(transport: T) -> Self {
Self::with_options(transport, SubOptions::default())
}
/// Creates a new subscriber socket with the given transport and options.
pub fn with_options(transport: T, options: SubOptions) -> Self {
let (to_driver, from_socket) = mpsc::channel(DEFAULT_BUFFER_SIZE);
let (to_socket, from_driver) = mpsc::channel(options.ingress_buffer_size);
let options = Arc::new(options);
let state = Arc::new(SocketState::default()); // SocketState uses default
let mut publishers = FxHashMap::default();
publishers.reserve(32);
let driver = SubDriver {
options: Arc::clone(&options),
transport,
from_socket,
to_socket,
connection_tasks: JoinMap::new(),
publishers,
subscribed_topics: HashSet::with_capacity(32),
state: Arc::clone(&state),
};
Self {
to_driver,
from_driver,
driver: Some(driver),
options,
state,
_marker: std::marker::PhantomData,
}
}
/// Asynchronously connects to the endpoint.
pub async fn connect_inner(&mut self, endpoint: A) -> Result<(), SubError> {
self.ensure_active_driver();
self.send_command(Command::Connect { endpoint }).await?;
Ok(())
}
/// Immediately send a connect command to the driver.
pub fn try_connect_inner(&mut self, endpoint: A) -> Result<(), SubError> {
self.ensure_active_driver();
self.try_send_command(Command::Connect { endpoint })?;
Ok(())
}
/// Asynchronously disconnects from the endpoint.
pub async fn disconnect_inner(&mut self, endpoint: A) -> Result<(), SubError> {
self.ensure_active_driver();
self.send_command(Command::Disconnect { endpoint }).await?;
Ok(())
}
/// Immediately send a disconnect command to the driver.
pub fn try_disconnect_inner(&mut self, endpoint: A) -> Result<(), SubError> {
self.ensure_active_driver();
self.try_send_command(Command::Disconnect { endpoint })?;
Ok(())
}
/// Subscribes to the given topic. This will subscribe to all connected publishers.
/// If the topic does not exist on a publisher, this will not return any data.
/// Any publishers that are connected after this call will also be subscribed to.
pub async fn subscribe(&mut self, topic: impl Into<String>) -> Result<(), SubError> {
self.ensure_active_driver();
let topic = topic.into();
if topic.starts_with("MSG") {
return Err(SubError::ReservedTopic);
}
self.send_command(Command::Subscribe { topic }).await?;
Ok(())
}
/// Immediately send a subscribe command to the driver.
pub fn try_subscribe(&mut self, topic: impl Into<String>) -> Result<(), SubError> {
self.ensure_active_driver();
let topic = topic.into();
if topic.starts_with("MSG") {
return Err(SubError::ReservedTopic);
}
self.try_send_command(Command::Subscribe { topic })?;
Ok(())
}
/// Unsubscribe from the given topic. This will unsubscribe from all connected publishers.
pub async fn unsubscribe(&mut self, topic: impl Into<String>) -> Result<(), SubError> {
self.ensure_active_driver();
let topic = topic.into();
if topic.starts_with("MSG") {
return Err(SubError::ReservedTopic);
}
self.send_command(Command::Unsubscribe { topic }).await?;
Ok(())
}
/// Immediately send an unsubscribe command to the driver.
pub fn try_unsubscribe(&mut self, topic: impl Into<String>) -> Result<(), SubError> {
self.ensure_active_driver();
let topic = topic.into();
if topic.starts_with("MSG") {
return Err(SubError::ReservedTopic);
}
self.try_send_command(Command::Unsubscribe { topic })?;
Ok(())
}
/// Sends a command to the driver, returning [`SubError::SocketClosed`] if the
/// driver has been dropped.
async fn send_command(&self, command: Command<A>) -> Result<(), SubError> {
self.to_driver.send(command).await.map_err(|_| SubError::SocketClosed)?;
Ok(())
}
fn try_send_command(&self, command: Command<A>) -> Result<(), SubError> {
use mpsc::error::TrySendError::*;
self.to_driver.try_send(command).map_err(|e| match e {
Full(_) => SubError::ChannelFull,
Closed(_) => SubError::SocketClosed,
})?;
Ok(())
}
/// Ensures that the driver task is running. This function will be called on every command,
/// which might be overkill, but it keeps the interface simple and is not in the hot path.
fn ensure_active_driver(&mut self) {
if let Some(driver) = self.driver.take() {
tokio::spawn(driver);
}
}
/// Returns the statistics specific to the subscriber socket.
pub fn stats(&self) -> &SubStats<A> {
&self.state.stats.specific
}
}
impl<T: Transport<A>, A: Address> Drop for SubSocket<T, A> {
    /// Asks the driver to shut down gracefully when the socket goes away.
    ///
    /// This is best-effort. `try_send` never waits, and a failure (full command buffer or
    /// already-closed driver) is deliberately ignored because `drop` has no way to report it.
    fn drop(&mut self) {
        let shutdown = Command::Shutdown;
        let _delivered = self.to_driver.try_send(shutdown).is_ok();
    }
}
impl<T: Transport<A> + Unpin, A: Address> Stream for SubSocket<T, A> {
    type Item = PubMessage<A>;

    /// Yields the next message forwarded by the driver.
    ///
    /// Returns `Ready(None)` once the driver side of the channel has closed.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // The socket is `Unpin`, so it is safe to take a plain mutable reference.
        let socket = self.get_mut();
        socket.from_driver.poll_recv(cx)
    }
}