Skip to content

Commit c0c5728

Browse files
committed
refactor(transport): replace watch channel with rwlock
1 parent 40c0161 commit c0c5728

File tree

4 files changed

+16
-25
lines changed

4 files changed

+16
-25
lines changed

msg-socket/src/req/driver.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -279,8 +279,7 @@ where
279279
if let Ok(io) = result {
280280
tracing::debug!(target = ?io.peer_addr(), "new connection");
281281

282-
let tx = this.socket_state.transport.0.clone();
283-
let metered = MeteredIo::new(io, tx);
282+
let metered = MeteredIo::new(io, Arc::clone(&this.socket_state.transport));
284283

285284
let mut framed = Framed::new(metered, reqrep::Codec::new());
286285
framed.set_backpressure_boundary(this.options.backpressure_boundary);

msg-socket/src/req/mod.rs

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
use bytes::Bytes;
2-
use std::{sync::Arc, time::Duration};
2+
use std::{
3+
sync::{Arc, RwLock},
4+
time::Duration,
5+
};
36
use thiserror::Error;
4-
use tokio::sync::{oneshot, watch};
7+
use tokio::sync::oneshot;
58

69
use msg_wire::{
710
compression::{CompressionType, Compressor},
@@ -185,8 +188,7 @@ impl ReqMessage {
185188
pub(crate) struct SocketState<S> {
186189
/// The socket stats.
187190
pub(crate) stats: Arc<SocketStats<ReqStats>>,
188-
/// The transport stats. This is None until a connection is established.
189-
pub(crate) transport: (watch::Sender<Arc<S>>, watch::Receiver<Arc<S>>),
191+
pub(crate) transport: Arc<RwLock<Arc<S>>>,
190192
}
191193

192194
// Manual clone implementation needed here because `S` is n`.

msg-socket/src/req/socket.rs

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,7 @@ use rustc_hash::FxHashMap;
33
use std::{marker::PhantomData, net::SocketAddr, path::PathBuf, sync::Arc, time::Duration};
44
use tokio::{
55
net::{ToSocketAddrs, lookup_host},
6-
sync::{
7-
mpsc, oneshot,
8-
watch::{self},
9-
},
6+
sync::{mpsc, oneshot},
107
};
118

129
use msg_transport::{Address, Transport};
@@ -76,7 +73,7 @@ where
7673
options: Arc::new(options),
7774
state: SocketState {
7875
stats: Arc::new(SocketStats::default()),
79-
transport: watch::channel(Arc::new(T::Stats::default())),
76+
transport: Default::default(),
8077
},
8178
compressor: None,
8279
_marker: PhantomData,
@@ -96,9 +93,7 @@ where
9693

9794
/// Borrow the latest transport-level stats snapshot.
9895
pub fn transport_stats(&self) -> Arc<T::Stats> {
99-
// NOTE: We clone the Arc here because purely borrowing the inner stats
100-
// would lock the channel.
101-
self.state.transport.1.borrow().clone()
96+
Arc::clone(&self.state.transport.read().unwrap())
10297
}
10398

10499
pub async fn request(&self, message: Bytes) -> Result<Bytes, ReqError> {

msg-transport/src/lib.rs

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,17 +10,14 @@ use std::{
1010
net::SocketAddr,
1111
path::PathBuf,
1212
pin::Pin,
13-
sync::Arc,
13+
sync::{Arc, RwLock},
1414
task::{Context, Poll},
1515
time::{Duration, Instant},
1616
};
1717

1818
use async_trait::async_trait;
1919
use futures::{Future, FutureExt};
20-
use tokio::{
21-
io::{AsyncRead, AsyncWrite},
22-
sync::watch,
23-
};
20+
use tokio::io::{AsyncRead, AsyncWrite};
2421

2522
pub mod ipc;
2623
#[cfg(feature = "quic")]
@@ -47,7 +44,7 @@ where
4744
/// The inner IO object.
4845
inner: Io,
4946
/// The sender for the stats.
50-
sender: watch::Sender<Arc<S>>,
47+
stats: Arc<RwLock<Arc<S>>>,
5148
/// The next time the stats should be refreshed.
5249
next_refresh: Instant,
5350
/// The interval at which the stats should be refreshed.
@@ -130,10 +127,10 @@ where
130127
/// stats. The `sender` is used to send the latest stats to the caller.
131128
///
132129
/// TODO: Specify configuration options.
133-
pub fn new(inner: Io, sender: watch::Sender<Arc<S>>) -> Self {
130+
pub fn new(inner: Io, stats: Arc<RwLock<Arc<S>>>) -> Self {
134131
Self {
135132
inner,
136-
sender,
133+
stats,
137134
_marker: PhantomData,
138135
next_refresh: Instant::now(),
139136
refresh_interval: Duration::from_secs(2),
@@ -146,9 +143,7 @@ where
146143
if self.next_refresh <= now {
147144
match S::try_from(&self.inner) {
148145
Ok(stats) => {
149-
if let Err(e) = self.sender.send(Arc::new(stats)) {
150-
tracing::error!(err = ?e, "failed to update transport stats");
151-
}
146+
*self.stats.write().unwrap() = Arc::new(stats);
152147
}
153148
Err(e) => tracing::error!(errror = ?e, "failed to gather transport stats"),
154149
}

0 commit comments

Comments
 (0)