Skip to content

Commit a5ff012

Browse files
refactor: rewrite wait_shutdown to not depend on tokio macros feature
1 parent e994b99 commit a5ff012

File tree

3 files changed

+36
-25
lines changed

3 files changed

+36
-25
lines changed

watermelon/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ rust-version.workspace = true
1313
features = ["websocket", "non-standard-zstd"]
1414

1515
[dependencies]
16-
tokio = { version = "1.44", features = ["macros", "rt", "sync", "time"] }
16+
tokio = { version = "1.44", features = ["rt", "sync", "time"] }
1717
arc-swap = "1"
1818
futures-core = "0.3"
1919
bytes = "1"

watermelon/src/client/mod.rs

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ use std::{fmt::Write, num::NonZero, process::abort, sync::Arc, time::Duration};
55
use arc_swap::ArcSwapOption;
66
use bytes::Bytes;
77
use tokio::{
8-
select,
98
sync::{
109
mpsc::{self, Permit, error::TrySendError},
1110
oneshot,
@@ -39,8 +38,8 @@ use self::tests::TestHandler;
3938
use crate::{
4039
core::{MultiplexedSubscription, Subscription},
4140
handler::{
42-
ConnectHandlerError, Handler, HandlerCommand, HandlerOutput, MULTIPLEXED_SUBSCRIPTION_ID,
43-
RecycledHandler,
41+
ConnectHandlerError, FuseShutdown, Handler, HandlerCommand, HandlerOutput,
42+
MULTIPLEXED_SUBSCRIPTION_ID, RecycledHandler,
4443
},
4544
util::atomic::{AtomicU64, Ordering},
4645
};
@@ -581,12 +580,9 @@ async fn connect(
581580
let mut delay = initial_delay;
582581

583582
loop {
584-
select! {
585-
biased;
586-
() = recycle.wait_shutdown() => {
587-
return None;
588-
},
589-
() = sleep(delay) => {},
583+
match recycle.fuse_shutdown(sleep(delay)).await {
584+
FuseShutdown::Output(()) => {}
585+
FuseShutdown::Shutdown => return None,
590586
}
591587

592588
match Handler::connect(addr, builder, recycle).await {

watermelon/src/handler/mod.rs

Lines changed: 30 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
use std::{
22
collections::{BTreeMap, VecDeque},
3-
future::Future,
3+
future::{self, Future},
44
mem,
55
num::NonZero,
66
ops::ControlFlow,
7-
pin::Pin,
7+
pin::{Pin, pin},
88
sync::Arc,
99
task::{Context, Poll},
1010
};
@@ -14,7 +14,6 @@ use bytes::Bytes;
1414
use thiserror::Error;
1515
use tokio::{
1616
net::TcpStream,
17-
select,
1817
sync::{
1918
mpsc::{self, error::TrySendError},
2019
oneshot,
@@ -155,17 +154,18 @@ impl Handler {
155154
flags.zstd_compression_level = builder.non_standard_zstd_compression_level;
156155
}
157156

158-
let (mut conn, info) = select! {
159-
biased;
160-
() = recycle.wait_shutdown() => {
161-
return Ok(None);
157+
let fut = timeout(
158+
builder.connect_timeout,
159+
easy_connect(addr, builder.auth_method.as_ref(), flags),
160+
);
161+
let (mut conn, info) = match recycle.fuse_shutdown(fut).await {
162+
FuseShutdown::Output(connect_result) => match connect_result {
163+
Ok(Ok(items)) => items,
164+
Ok(Err(err)) => return Err((ConnectHandlerError::Connect(err), recycle)),
165+
Err(_elapsed) => return Err((ConnectHandlerError::TimedOut, recycle)),
162166
},
163-
connect_result = timeout(builder.connect_timeout, easy_connect(addr, builder.auth_method.as_ref(), flags)) => {
164-
match connect_result {
165-
Ok(Ok(items)) => items,
166-
Ok(Err(err)) => return Err((ConnectHandlerError::Connect(err), recycle)),
167-
Err(_elapsed) => return Err((ConnectHandlerError::TimedOut, recycle)),
168-
}
167+
FuseShutdown::Shutdown => {
168+
return Ok(None);
169169
}
170170
};
171171

@@ -674,6 +674,12 @@ impl Handler {
674674
}
675675
}
676676

677+
#[derive(Debug)]
678+
pub(crate) enum FuseShutdown<T> {
679+
Output(T),
680+
Shutdown,
681+
}
682+
677683
impl RecycledHandler {
678684
pub(crate) fn new(
679685
commands: mpsc::Receiver<HandlerCommand>,
@@ -699,8 +705,17 @@ impl RecycledHandler {
699705
&self.multiplexed_subscription_prefix
700706
}
701707

702-
pub(crate) async fn wait_shutdown(&mut self) {
703-
let _ = self.shutdown_recv.recv().await;
708+
pub(crate) async fn fuse_shutdown<F: Future>(&mut self, fut: F) -> FuseShutdown<F::Output> {
709+
let mut fut = pin!(fut);
710+
711+
future::poll_fn(|cx| {
712+
if self.shutdown_recv.poll_recv(cx).is_ready() {
713+
Poll::Ready(FuseShutdown::Shutdown)
714+
} else {
715+
fut.as_mut().poll(cx).map(FuseShutdown::Output)
716+
}
717+
})
718+
.await
704719
}
705720
}
706721

0 commit comments

Comments
 (0)