Skip to content

Commit 0aa18c0

Browse files
committed
Notify the Sender that the Encoder has been dropped
1 parent 886cd75 commit 0aa18c0

File tree

5 files changed

+100
-50
lines changed

5 files changed

+100
-50
lines changed

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ async-std = { version = "1.6.0", features = ["unstable"] }
2121
http-types = "2.0.1"
2222
log = "0.4.8"
2323
memchr = "2.3.3"
24-
pin-project-lite = "0.1.4"
24+
pin-project = "0.4.22"
2525

2626
[dev-dependencies]
2727
femme = "2.0.0"

src/encoder.rs

Lines changed: 61 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,26 @@ use async_std::io::Read as AsyncRead;
66
use async_std::prelude::*;
77
use async_std::task::{ready, Context, Poll};
88
use std::pin::Pin;
9+
use std::sync::atomic::{AtomicBool, Ordering};
10+
use std::sync::Arc;
911

10-
pin_project_lite::pin_project! {
11-
/// An SSE protocol encoder.
12-
#[derive(Debug)]
13-
pub struct Encoder {
14-
buf: Option<Vec<u8>>,
15-
#[pin]
16-
receiver: sync::Receiver<Vec<u8>>,
17-
cursor: usize,
12+
use pin_project::{pin_project, pinned_drop};
13+
14+
#[pin_project(PinnedDrop)]
15+
/// An SSE protocol encoder.
16+
#[derive(Debug)]
17+
pub struct Encoder {
18+
buf: Option<Vec<u8>>,
19+
#[pin]
20+
receiver: sync::Receiver<Vec<u8>>,
21+
cursor: usize,
22+
disconnected: Arc<AtomicBool>,
23+
}
24+
25+
#[pinned_drop]
26+
impl PinnedDrop for Encoder {
27+
fn drop(self: Pin<&mut Self>) {
28+
self.disconnected.store(true, Ordering::Relaxed);
1829
}
1930
}
2031

@@ -79,53 +90,80 @@ impl AsyncRead for Encoder {
7990
// }
8091

8192
/// The sending side of the encoder.
82-
#[derive(Debug)]
83-
pub struct Sender(sync::Sender<Vec<u8>>);
93+
#[derive(Debug, Clone)]
94+
pub struct Sender {
95+
sender: sync::Sender<Vec<u8>>,
96+
disconnected: Arc<std::sync::atomic::AtomicBool>,
97+
}
8498

8599
/// Create a new SSE encoder.
86100
pub fn encode() -> (Sender, Encoder) {
87101
let (sender, receiver) = sync::channel(1);
102+
let disconnected = Arc::new(AtomicBool::new(false));
103+
88104
let encoder = Encoder {
89105
receiver,
90106
buf: None,
91107
cursor: 0,
108+
disconnected: disconnected.clone(),
109+
};
110+
111+
let sender = Sender {
112+
sender,
113+
disconnected,
92114
};
93-
(Sender(sender), encoder)
115+
116+
(sender, encoder)
94117
}
95118

119+
/// An error that represents that the [Encoder] has been dropped.
120+
#[derive(Debug, Eq, PartialEq)]
121+
pub struct DisconnectedError;
122+
impl std::error::Error for DisconnectedError {}
123+
impl std::fmt::Display for DisconnectedError {
124+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
125+
write!(f, "Disconnected")
126+
}
127+
}
128+
129+
#[must_use]
96130
impl Sender {
97131
/// Send a new message over SSE.
98-
pub async fn send(&self, name: &str, data: &str, id: Option<&str>) {
132+
pub async fn send(
133+
&self,
134+
name: &str,
135+
data: &str,
136+
id: Option<&str>,
137+
) -> Result<(), DisconnectedError> {
138+
if self.disconnected.load(Ordering::Relaxed) {
139+
return Err(DisconnectedError);
140+
}
141+
99142
// Write the event name
100143
let msg = format!("event:{}\n", name);
101-
self.0.send(msg.into_bytes()).await;
144+
self.sender.send(msg.into_bytes()).await;
102145

103146
// Write the id
104147
if let Some(id) = id {
105-
self.0.send(format!("id:{}\n", id).into_bytes()).await;
148+
self.sender.send(format!("id:{}\n", id).into_bytes()).await;
106149
}
107150

108151
// Write the data section, and end.
109152
let msg = format!("data:{}\n\n", data);
110-
self.0.send(msg.into_bytes()).await;
153+
self.sender.send(msg.into_bytes()).await;
154+
Ok(())
111155
}
112156

113157
/// Send a new "retry" message over SSE.
114158
pub async fn send_retry(&self, dur: Duration, id: Option<&str>) {
115159
// Write the id
116160
if let Some(id) = id {
117-
self.0.send(format!("id:{}\n", id).into_bytes()).await;
161+
self.sender.send(format!("id:{}\n", id).into_bytes()).await;
118162
}
119163

120164
// Write the retry section, and end.
121165
let dur = dur.as_secs_f64() as u64;
122166
let msg = format!("retry:{}\n\n", dur);
123-
self.0.send(msg.into_bytes()).await;
124-
}
125-
}
126-
127-
impl Clone for Sender {
128-
fn clone(&self) -> Self {
129-
Self(self.0.clone())
167+
self.sender.send(msg.into_bytes()).await;
130168
}
131169
}

src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ mod lines;
4343
mod message;
4444

4545
pub use decoder::{decode, Decoder};
46-
pub use encoder::{encode, Encoder, Sender};
46+
pub use encoder::{encode, DisconnectedError, Encoder, Sender};
4747
pub use event::Event;
4848
pub use handshake::upgrade;
4949
pub use message::Message;

src/lines.rs

Lines changed: 18 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2,30 +2,29 @@ use std::mem;
22
use std::pin::Pin;
33
use std::str;
44

5-
use pin_project_lite::pin_project;
5+
use pin_project::pin_project;
66

77
use async_std::io::{self, BufRead};
88
use async_std::stream::Stream;
99
use async_std::task::{ready, Context, Poll};
1010

11-
pin_project! {
12-
/// A stream of lines in a byte stream.
13-
///
14-
/// This stream is created by the [`lines`] method on types that implement [`BufRead`].
15-
///
16-
/// This type is an async version of [`std::io::Lines`].
17-
///
18-
/// [`lines`]: trait.BufRead.html#method.lines
19-
/// [`BufRead`]: trait.BufRead.html
20-
/// [`std::io::Lines`]: https://doc.rust-lang.org/std/io/struct.Lines.html
21-
#[derive(Debug)]
22-
pub(crate) struct Lines<R> {
23-
#[pin]
24-
pub(crate) reader: R,
25-
pub(crate) buf: String,
26-
pub(crate) bytes: Vec<u8>,
27-
pub(crate) read: usize,
28-
}
11+
/// A stream of lines in a byte stream.
12+
///
13+
/// This stream is created by the [`lines`] method on types that implement [`BufRead`].
14+
///
15+
/// This type is an async version of [`std::io::Lines`].
16+
///
17+
/// [`lines`]: trait.BufRead.html#method.lines
18+
/// [`BufRead`]: trait.BufRead.html
19+
/// [`std::io::Lines`]: https://doc.rust-lang.org/std/io/struct.Lines.html
20+
#[pin_project]
21+
#[derive(Debug)]
22+
pub(crate) struct Lines<R> {
23+
#[pin]
24+
pub(crate) reader: R,
25+
pub(crate) buf: String,
26+
pub(crate) bytes: Vec<u8>,
27+
pub(crate) read: usize,
2928
}
3029

3130
impl<R> Lines<R> {

tests/encode.rs

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,7 @@ fn assert_retry(event: &Event, dur: u64) {
2929
#[async_std::test]
3030
async fn encode_message() -> http_types::Result<()> {
3131
let (sender, encoder) = encode();
32-
task::spawn(async move {
33-
sender.send("cat", "chashu", None).await;
34-
});
32+
task::spawn(async move { sender.send("cat", "chashu", None).await });
3533

3634
let mut reader = decode(BufReader::new(encoder));
3735
let event = reader.next().await.unwrap()?;
@@ -42,9 +40,7 @@ async fn encode_message() -> http_types::Result<()> {
4240
#[async_std::test]
4341
async fn encode_message_with_id() -> http_types::Result<()> {
4442
let (sender, encoder) = encode();
45-
task::spawn(async move {
46-
sender.send("cat", "chashu", Some("0")).await;
47-
});
43+
task::spawn(async move { sender.send("cat", "chashu", Some("0")).await });
4844

4945
let mut reader = decode(BufReader::new(encoder));
5046
let event = reader.next().await.unwrap()?;
@@ -65,3 +61,20 @@ async fn encode_retry() -> http_types::Result<()> {
6561
assert_retry(&event, 12);
6662
Ok(())
6763
}
64+
65+
#[async_std::test]
66+
async fn dropping_encoder() -> http_types::Result<()> {
67+
let (sender, encoder) = encode();
68+
let reader = BufReader::new(encoder);
69+
let sender_clone = sender.clone();
70+
task::spawn(async move { sender_clone.send("cat", "chashu", Some("0")).await.unwrap() });
71+
72+
//move the encoder into Lines, which gets dropped after this
73+
assert_eq!(reader.lines().next().await.unwrap().unwrap(), "event:cat");
74+
75+
assert_eq!(
76+
sender.send("cat", "nori", None).await,
77+
Err(async_sse::DisconnectedError)
78+
);
79+
Ok(())
80+
}

0 commit comments

Comments
 (0)