Skip to content

Commit e3db5a5

Browse files
fix(futures-bounded): register replaced Streams/Futures as ready
Currently, when a `Stream` or `Future` is replaced with a new one, it might happen that we miss a task wake-up and thus the task polling `FuturesMap` or `StreamMap` is never called again. This can be fixed by first removing the old `Stream`/`Future` and properly adding a new one via `.push`. The inner `SelectAll` calls a waker in that case which allows the outer task to continue. Pull-Request: #4865.
1 parent b6eb2bf commit e3db5a5

File tree

7 files changed

+153
-62
lines changed

7 files changed

+153
-62
lines changed

Cargo.lock

Lines changed: 2 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ rust-version = "1.73.0"
7171

7272
[workspace.dependencies]
7373
asynchronous-codec = { version = "0.7.0" }
74-
futures-bounded = { version = "0.2.1", path = "misc/futures-bounded" }
74+
futures-bounded = { version = "0.2.2", path = "misc/futures-bounded" }
7575
libp2p = { version = "0.53.0", path = "libp2p" }
7676
libp2p-allow-block-list = { version = "0.3.0", path = "misc/allow-block-list" }
7777
libp2p-autonat = { version = "0.12.0", path = "protocols/autonat" }

misc/futures-bounded/CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
## 0.2.2
2+
3+
- Fix an issue where `{Futures,Stream}Map` returns `Poll::Pending` despite being ready after an item has been replaced as part of `try_push`.
4+
See [PR 4865](https://github.com/libp2p/rust-lib2pp/pulls/4865).
5+
16
## 0.2.1
27

38
- Add `.len()` getter to `FuturesMap`, `FuturesSet`, `StreamMap` and `StreamSet`.

misc/futures-bounded/Cargo.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "futures-bounded"
3-
version = "0.2.1"
3+
version = "0.2.2"
44
edition = "2021"
55
rust-version.workspace = true
66
license = "MIT"
@@ -17,7 +17,8 @@ futures-util = { version = "0.3.29" }
1717
futures-timer = "3.0.2"
1818

1919
[dev-dependencies]
20-
tokio = { version = "1.34.0", features = ["macros", "rt"] }
20+
tokio = { version = "1.34.0", features = ["macros", "rt", "sync"] }
21+
futures = "0.3.28"
2122

2223
[lints]
2324
workspace = true

misc/futures-bounded/src/futures_map.rs

Lines changed: 89 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
use std::future::Future;
22
use std::hash::Hash;
3-
use std::mem;
43
use std::pin::Pin;
54
use std::task::{Context, Poll, Waker};
65
use std::time::Duration;
6+
use std::{future, mem};
77

88
use futures_timer::Delay;
99
use futures_util::future::BoxFuture;
@@ -38,6 +38,7 @@ impl<ID, O> FuturesMap<ID, O> {
3838
impl<ID, O> FuturesMap<ID, O>
3939
where
4040
ID: Clone + Hash + Eq + Send + Unpin + 'static,
41+
O: 'static,
4142
{
4243
/// Push a future into the map.
4344
///
@@ -58,32 +59,30 @@ where
5859
waker.wake();
5960
}
6061

61-
match self.inner.iter_mut().find(|tagged| tagged.tag == future_id) {
62-
None => {
63-
self.inner.push(TaggedFuture {
64-
tag: future_id,
65-
inner: TimeoutFuture {
66-
inner: future.boxed(),
67-
timeout: Delay::new(self.timeout),
68-
},
69-
});
70-
71-
Ok(())
72-
}
73-
Some(existing) => {
74-
let old_future = mem::replace(
75-
&mut existing.inner,
76-
TimeoutFuture {
77-
inner: future.boxed(),
78-
timeout: Delay::new(self.timeout),
79-
},
80-
);
81-
82-
Err(PushError::Replaced(old_future.inner))
83-
}
62+
let old = self.remove(future_id.clone());
63+
self.inner.push(TaggedFuture {
64+
tag: future_id,
65+
inner: TimeoutFuture {
66+
inner: future.boxed(),
67+
timeout: Delay::new(self.timeout),
68+
cancelled: false,
69+
},
70+
});
71+
match old {
72+
None => Ok(()),
73+
Some(old) => Err(PushError::Replaced(old)),
8474
}
8575
}
8676

77+
pub fn remove(&mut self, id: ID) -> Option<BoxFuture<'static, O>> {
78+
let tagged = self.inner.iter_mut().find(|s| s.tag == id)?;
79+
80+
let inner = mem::replace(&mut tagged.inner.inner, future::pending().boxed());
81+
tagged.inner.cancelled = true;
82+
83+
Some(inner)
84+
}
85+
8786
pub fn len(&self) -> usize {
8887
self.inner.len()
8988
}
@@ -104,39 +103,55 @@ where
104103
}
105104

106105
pub fn poll_unpin(&mut self, cx: &mut Context<'_>) -> Poll<(ID, Result<O, Timeout>)> {
107-
let maybe_result = futures_util::ready!(self.inner.poll_next_unpin(cx));
106+
loop {
107+
let maybe_result = futures_util::ready!(self.inner.poll_next_unpin(cx));
108108

109-
match maybe_result {
110-
None => {
111-
self.empty_waker = Some(cx.waker().clone());
112-
Poll::Pending
109+
match maybe_result {
110+
None => {
111+
self.empty_waker = Some(cx.waker().clone());
112+
return Poll::Pending;
113+
}
114+
Some((id, Ok(output))) => return Poll::Ready((id, Ok(output))),
115+
Some((id, Err(TimeoutError::Timeout))) => {
116+
return Poll::Ready((id, Err(Timeout::new(self.timeout))))
117+
}
118+
Some((_, Err(TimeoutError::Cancelled))) => continue,
113119
}
114-
Some((id, Ok(output))) => Poll::Ready((id, Ok(output))),
115-
Some((id, Err(_timeout))) => Poll::Ready((id, Err(Timeout::new(self.timeout)))),
116120
}
117121
}
118122
}
119123

120124
struct TimeoutFuture<F> {
121125
inner: F,
122126
timeout: Delay,
127+
128+
cancelled: bool,
123129
}
124130

125131
impl<F> Future for TimeoutFuture<F>
126132
where
127133
F: Future + Unpin,
128134
{
129-
type Output = Result<F::Output, ()>;
135+
type Output = Result<F::Output, TimeoutError>;
130136

131137
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
138+
if self.cancelled {
139+
return Poll::Ready(Err(TimeoutError::Cancelled));
140+
}
141+
132142
if self.timeout.poll_unpin(cx).is_ready() {
133-
return Poll::Ready(Err(()));
143+
return Poll::Ready(Err(TimeoutError::Timeout));
134144
}
135145

136146
self.inner.poll_unpin(cx).map(Ok)
137147
}
138148
}
139149

150+
enum TimeoutError {
151+
Timeout,
152+
Cancelled,
153+
}
154+
140155
struct TaggedFuture<T, F> {
141156
tag: T,
142157
inner: F,
@@ -158,6 +173,8 @@ where
158173

159174
#[cfg(test)]
160175
mod tests {
176+
use futures::channel::oneshot;
177+
use futures_util::task::noop_waker_ref;
161178
use std::future::{pending, poll_fn, ready};
162179
use std::pin::Pin;
163180
use std::time::Instant;
@@ -197,6 +214,45 @@ mod tests {
197214
assert!(result.is_err())
198215
}
199216

217+
#[test]
218+
fn resources_of_removed_future_are_cleaned_up() {
219+
let mut futures = FuturesMap::new(Duration::from_millis(100), 1);
220+
221+
let _ = futures.try_push("ID", pending::<()>());
222+
futures.remove("ID");
223+
224+
let poll = futures.poll_unpin(&mut Context::from_waker(noop_waker_ref()));
225+
assert!(poll.is_pending());
226+
227+
assert_eq!(futures.len(), 0);
228+
}
229+
230+
#[tokio::test]
231+
async fn replaced_pending_future_is_polled() {
232+
let mut streams = FuturesMap::new(Duration::from_millis(100), 3);
233+
234+
let (_tx1, rx1) = oneshot::channel();
235+
let (tx2, rx2) = oneshot::channel();
236+
237+
let _ = streams.try_push("ID1", rx1);
238+
let _ = streams.try_push("ID2", rx2);
239+
240+
let _ = tx2.send(2);
241+
let (id, res) = poll_fn(|cx| streams.poll_unpin(cx)).await;
242+
assert_eq!(id, "ID2");
243+
assert_eq!(res.unwrap().unwrap(), 2);
244+
245+
let (new_tx1, new_rx1) = oneshot::channel();
246+
let replaced = streams.try_push("ID1", new_rx1);
247+
assert!(matches!(replaced.unwrap_err(), PushError::Replaced(_)));
248+
249+
let _ = new_tx1.send(4);
250+
let (id, res) = poll_fn(|cx| streams.poll_unpin(cx)).await;
251+
252+
assert_eq!(id, "ID1");
253+
assert_eq!(res.unwrap().unwrap(), 4);
254+
}
255+
200256
// Each future causes a delay, `Task` only has a capacity of 1, meaning they must be processed in sequence.
201257
// We stop after NUM_FUTURES tasks, meaning the overall execution must at least take DELAY * NUM_FUTURES.
202258
#[tokio::test]

misc/futures-bounded/src/futures_set.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,10 @@ impl<O> FuturesSet<O> {
2323
}
2424
}
2525

26-
impl<O> FuturesSet<O> {
26+
impl<O> FuturesSet<O>
27+
where
28+
O: 'static,
29+
{
2730
/// Push a future into the list.
2831
///
2932
/// This method adds the given future to the list.

misc/futures-bounded/src/stream_map.rs

Lines changed: 49 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -53,33 +53,22 @@ where
5353
waker.wake();
5454
}
5555

56-
match self.inner.iter_mut().find(|tagged| tagged.key == id) {
57-
None => {
58-
self.inner.push(TaggedStream::new(
59-
id,
60-
TimeoutStream {
61-
inner: stream.boxed(),
62-
timeout: Delay::new(self.timeout),
63-
},
64-
));
65-
66-
Ok(())
67-
}
68-
Some(existing) => {
69-
let old = mem::replace(
70-
&mut existing.inner,
71-
TimeoutStream {
72-
inner: stream.boxed(),
73-
timeout: Delay::new(self.timeout),
74-
},
75-
);
76-
77-
Err(PushError::Replaced(old.inner))
78-
}
56+
let old = self.remove(id.clone());
57+
self.inner.push(TaggedStream::new(
58+
id,
59+
TimeoutStream {
60+
inner: stream.boxed(),
61+
timeout: Delay::new(self.timeout),
62+
},
63+
));
64+
65+
match old {
66+
None => Ok(()),
67+
Some(old) => Err(PushError::Replaced(old)),
7968
}
8069
}
8170

82-
pub fn remove(&mut self, id: ID) -> Option<BoxStream<O>> {
71+
pub fn remove(&mut self, id: ID) -> Option<BoxStream<'static, O>> {
8372
let tagged = self.inner.iter_mut().find(|s| s.key == id)?;
8473

8574
let inner = mem::replace(&mut tagged.inner.inner, stream::pending().boxed());
@@ -189,7 +178,9 @@ where
189178

190179
#[cfg(test)]
191180
mod tests {
181+
use futures::channel::mpsc;
192182
use futures_util::stream::{once, pending};
183+
use futures_util::SinkExt;
193184
use std::future::{poll_fn, ready, Future};
194185
use std::pin::Pin;
195186
use std::time::Instant;
@@ -266,6 +257,40 @@ mod tests {
266257
);
267258
}
268259

260+
#[tokio::test]
261+
async fn replaced_stream_is_still_registered() {
262+
let mut streams = StreamMap::new(Duration::from_millis(100), 3);
263+
264+
let (mut tx1, rx1) = mpsc::channel(5);
265+
let (mut tx2, rx2) = mpsc::channel(5);
266+
267+
let _ = streams.try_push("ID1", rx1);
268+
let _ = streams.try_push("ID2", rx2);
269+
270+
let _ = tx2.send(2).await;
271+
let _ = tx1.send(1).await;
272+
let _ = tx2.send(3).await;
273+
let (id, res) = poll_fn(|cx| streams.poll_next_unpin(cx)).await;
274+
assert_eq!(id, "ID1");
275+
assert_eq!(res.unwrap().unwrap(), 1);
276+
let (id, res) = poll_fn(|cx| streams.poll_next_unpin(cx)).await;
277+
assert_eq!(id, "ID2");
278+
assert_eq!(res.unwrap().unwrap(), 2);
279+
let (id, res) = poll_fn(|cx| streams.poll_next_unpin(cx)).await;
280+
assert_eq!(id, "ID2");
281+
assert_eq!(res.unwrap().unwrap(), 3);
282+
283+
let (mut new_tx1, new_rx1) = mpsc::channel(5);
284+
let replaced = streams.try_push("ID1", new_rx1);
285+
assert!(matches!(replaced.unwrap_err(), PushError::Replaced(_)));
286+
287+
let _ = new_tx1.send(4).await;
288+
let (id, res) = poll_fn(|cx| streams.poll_next_unpin(cx)).await;
289+
290+
assert_eq!(id, "ID1");
291+
assert_eq!(res.unwrap().unwrap(), 4);
292+
}
293+
269294
// Each stream emits 1 item with delay, `Task` only has a capacity of 1, meaning they must be processed in sequence.
270295
// We stop after NUM_STREAMS tasks, meaning the overall execution must at least take DELAY * NUM_STREAMS.
271296
#[tokio::test]

0 commit comments

Comments
 (0)