Skip to content

Commit 69a8c7d

Browse files
committed
rework encoder
1 parent ddf689e commit 69a8c7d

File tree

7 files changed

+145
-82
lines changed

7 files changed

+145
-82
lines changed

Cargo.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,15 @@ authors = [
1717
[features]
1818

1919
[dependencies]
20-
async-std = "1.5.0"
20+
async-std = { version = "1.5.0", features = ["unstable"] }
2121
http-types = "1.0.1"
2222
log = "0.4.8"
2323
memchr = "2.3.3"
2424
pin-project-lite = "0.1.4"
2525

2626
[dev-dependencies]
2727
femme = "1.3.0"
28-
async-std = { version = "1.5.0", features = ["attributes"] }
28+
async-std = { version = "1.5.0", features = ["attributes", "unstable"] }
2929

3030
[patch.crates-io]
3131
async-std = { path = "../async-std" }

src/encoder.rs

Lines changed: 102 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,71 +1,127 @@
1-
use async_std::io::prelude::*;
2-
use async_std::io::Write as AsyncWrite;
1+
use async_std::sync;
32
use std::io;
43
use std::time::Duration;
54

6-
/// An SSE protocol encoder.
7-
#[derive(Debug)]
8-
pub struct Encoder<W> {
9-
writer: W,
10-
}
5+
use async_std::io::Read as AsyncRead;
6+
use async_std::prelude::*;
7+
use async_std::task::{ready, Context, Poll};
8+
use std::pin::Pin;
119

12-
/// Encode a new SSE connection.
13-
pub fn encode<W: AsyncWrite + Unpin>(writer: W) -> Encoder<W> {
14-
Encoder { writer }
10+
pin_project_lite::pin_project! {
11+
/// An SSE protocol encoder.
12+
#[derive(Debug)]
13+
pub struct Encoder {
14+
buf: Option<Vec<u8>>,
15+
#[pin]
16+
receiver: sync::Receiver<Vec<u8>>,
17+
cursor: usize,
18+
}
1519
}
1620

17-
impl<W> Encoder<W> {
18-
/// Access the inner writer from the Encoder.
19-
pub fn into_writer(self) -> W {
20-
self.writer
21+
impl AsyncRead for Encoder {
22+
fn poll_read(
23+
mut self: Pin<&mut Self>,
24+
cx: &mut Context<'_>,
25+
buf: &mut [u8],
26+
) -> Poll<io::Result<usize>> {
27+
// Request a new buffer if we don't have one yet.
28+
if let None = self.buf {
29+
self.buf = match ready!(Pin::new(&mut self.receiver).poll_next(cx)) {
30+
Some(buf) => {
31+
log::trace!("> Received a new buffer with len {}", buf.len());
32+
Some(buf)
33+
}
34+
None => {
35+
log::trace!("> Encoder done reading");
36+
return Poll::Ready(Ok(0));
37+
}
38+
};
39+
};
40+
41+
// Write the current buffer to completion.
42+
let local_buf = self.buf.as_mut().unwrap();
43+
let local_len = local_buf.len();
44+
let max = buf.len().min(local_buf.len());
45+
buf[..max].clone_from_slice(&local_buf[..max]);
46+
47+
self.cursor += max;
48+
49+
// Reset values if we're done reading.
50+
if self.cursor == local_len {
51+
self.buf = None;
52+
self.cursor = 0;
53+
};
54+
55+
// Return bytes read.
56+
Poll::Ready(Ok(max))
2157
}
2258
}
2359

24-
impl<W: AsyncWrite + Unpin> Encoder<W> {
60+
// impl AsyncBufRead for Encoder {
61+
// fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
62+
// match ready!(self.project().receiver.poll_next(cx)) {
63+
// Some(buf) => match &self.buf {
64+
// None => self.project().buf = &mut Some(buf),
65+
// Some(local_buf) => local_buf.extend(buf),
66+
// },
67+
// None => {
68+
// if let None = self.buf {
69+
// self.project().buf = &mut Some(vec![]);
70+
// };
71+
// }
72+
// };
73+
// Poll::Ready(Ok(self.buf.as_ref().unwrap()))
74+
// }
75+
76+
// fn consume(self: Pin<&mut Self>, amt: usize) {
77+
// Pin::new(self).cursor += amt;
78+
// }
79+
// }
80+
81+
/// The sending side of the encoder.
82+
#[derive(Debug)]
83+
pub struct Sender(sync::Sender<Vec<u8>>);
84+
85+
/// Create a new SSE encoder.
86+
pub fn encode() -> (Sender, Encoder) {
87+
let (sender, receiver) = sync::channel(1);
88+
let encoder = Encoder {
89+
receiver,
90+
buf: None,
91+
cursor: 0,
92+
};
93+
(Sender(sender), encoder)
94+
}
95+
96+
impl Sender {
2597
/// Send a new message over SSE.
26-
pub async fn send(&mut self, name: &str, data: &[u8], id: Option<&str>) -> io::Result<()> {
98+
pub async fn send(&self, name: &str, data: &[u8], id: Option<&str>) {
2799
// Write the event name
28-
self.writer.write_all(b"event:").await?;
29-
self.writer.write_all(name.as_bytes()).await?;
30-
self.writer.write_all(b"\n").await?;
100+
let msg = format!("event:{}\n", name);
101+
self.0.send(msg.into_bytes()).await;
31102

32103
// Write the id
33104
if let Some(id) = id {
34-
self.writer.write_all(b"id:").await?;
35-
self.writer.write_all(id.as_bytes()).await?;
36-
self.writer.write_all(b"\n").await?;
105+
self.0.send(format!("id:{}\n", id).into_bytes()).await;
37106
}
38107

39-
// Write the section
40-
self.writer.write_all(b"data:").await?;
41-
self.writer.write_all(data).await?;
42-
self.writer.write_all(b"\n").await?;
43-
44-
// Finalize the message
45-
self.writer.write_all(b"\n").await?;
46-
47-
Ok(())
108+
// Write the data section, and end.
109+
let mut msg = b"data:".to_vec();
110+
msg.extend_from_slice(data);
111+
msg.extend_from_slice(b"\n\n");
112+
self.0.send(msg).await;
48113
}
49114

50115
/// Send a new "retry" message over SSE.
51-
pub async fn send_retry(&mut self, dur: Duration, id: Option<&str>) -> io::Result<()> {
116+
pub async fn send_retry(&self, dur: Duration, id: Option<&str>) {
52117
// Write the id
53118
if let Some(id) = id {
54-
self.writer.write_all(b"id:").await?;
55-
self.writer.write_all(id.as_bytes()).await?;
56-
self.writer.write_all(b"\n").await?;
119+
self.0.send(format!("id:{}\n", id).into_bytes()).await;
57120
}
58121

59-
// Write the section
60-
self.writer.write_all(b"retry:").await?;
61-
self.writer
62-
.write_all(&format!("{}", dur.as_secs_f64() as u64).as_bytes())
63-
.await?;
64-
self.writer.write_all(b"\n").await?;
65-
66-
// Finalize the message
67-
self.writer.write_all(b"\n").await?;
68-
69-
Ok(())
122+
// Write the retry section, and end.
123+
let dur = dur.as_secs_f64() as u64;
124+
let msg = format!("retry:{}\n\n", dur);
125+
self.0.send(msg.into_bytes()).await;
70126
}
71127
}

src/event.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use crate::Message;
22

33
use std::time::Duration;
44

5-
/// The kind of event sent.
5+
/// The kind of SSE event sent.
66
#[derive(Debug, Eq, PartialEq)]
77
pub enum Event {
88
/// A retry frame, signaling a new retry duration must be used..

src/handshake.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
/// Upgrade an HTTP connection into an SSE session.
2+
pub fn upgrade(headers: &mut impl AsMut<http_types::Headers>) -> http_types::Result<()> {
3+
let headers = headers.as_mut();
4+
headers.insert("Cache-Control", "no-cache")?;
5+
headers.insert("Content-Type", "text/event-stream")?;
6+
Ok(())
7+
}

src/lib.rs

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,22 +4,20 @@
44
//!
55
//! ```
66
//! use async_sse::{decode, encode, Event};
7-
//! use async_std::io::Cursor;
87
//! use async_std::prelude::*;
8+
//! use async_std::io::BufReader;
9+
//! use async_std::task;
910
//!
1011
//! #[async_std::main]
1112
//! async fn main() -> http_types::Result<()> {
12-
//! let buf = Cursor::new(vec![]);
13-
//!
14-
//! // Encode messages to an AsyncWrite.
15-
//! let mut encoder = encode(buf);
16-
//! encoder.send("cat", b"chashu", None).await?;
17-
//!
18-
//! let mut buf = encoder.into_writer();
19-
//! buf.set_position(0);
20-
//!
21-
//! // Decode messages from an AsyncRead.
22-
//! let mut reader = decode(buf);
13+
//! // Create an encoder + sender pair and send a message.
14+
//! let (sender, encoder) = encode();
15+
//! task::spawn(async move {
16+
//! sender.send("cat", b"chashu", None).await;
17+
//! });
18+
//!
19+
//! // Decode messages using a decoder.
20+
//! let mut reader = decode(BufReader::new(encoder));
2321
//! let event = reader.next().await.unwrap()?;
2422
//! // Match and handle the event
2523
//!
@@ -40,12 +38,14 @@
4038
mod decoder;
4139
mod encoder;
4240
mod event;
41+
mod handshake;
4342
mod lines;
4443
mod message;
4544

4645
pub use decoder::{decode, Decoder};
47-
pub use encoder::{encode, Encoder};
46+
pub use encoder::{encode, Encoder, Sender};
4847
pub use event::Event;
48+
pub use handshake::upgrade;
4949
pub use message::Message;
5050

5151
pub(crate) use lines::Lines;

src/message.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/// A data event.
1+
/// An SSE event with a data payload.
22
#[derive(Debug, PartialEq, Eq, Hash)]
33
pub struct Message {
44
/// The ID of this event.

tests/encode.rs

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use async_sse::{decode, encode, Event};
2-
use async_std::io::Cursor;
2+
use async_std::io::BufReader;
33
use async_std::prelude::*;
4+
use async_std::task;
45
use std::time::Duration;
56

67
/// Assert a Message.
@@ -24,43 +25,42 @@ fn assert_retry(event: &Event, dur: u64) {
2425
assert_eq!(dur, &expected);
2526
}
2627
}
28+
2729
#[async_std::test]
2830
async fn encode_message() -> http_types::Result<()> {
29-
let buf = Cursor::new(vec![]);
30-
let mut encoder = encode(buf);
31-
encoder.send("cat", b"chashu", None).await?;
32-
let mut buf = encoder.into_writer();
33-
buf.set_position(0);
31+
let (sender, encoder) = encode();
32+
task::spawn(async move {
33+
sender.send("cat", b"chashu", None).await;
34+
});
3435

35-
let mut reader = decode(buf);
36+
let mut reader = decode(BufReader::new(encoder));
3637
let event = reader.next().await.unwrap()?;
3738
assert_message(&event, "cat", "chashu", None);
3839
Ok(())
3940
}
4041

4142
#[async_std::test]
4243
async fn encode_message_with_id() -> http_types::Result<()> {
43-
let buf = Cursor::new(vec![]);
44-
let mut encoder = encode(buf);
45-
encoder.send("cat", b"chashu", Some("0")).await?;
46-
let mut buf = encoder.into_writer();
47-
buf.set_position(0);
44+
let (sender, encoder) = encode();
45+
task::spawn(async move {
46+
sender.send("cat", b"chashu", Some("0")).await;
47+
});
4848

49-
let mut reader = decode(buf);
49+
let mut reader = decode(BufReader::new(encoder));
5050
let event = reader.next().await.unwrap()?;
5151
assert_message(&event, "cat", "chashu", Some("0"));
5252
Ok(())
5353
}
5454

5555
#[async_std::test]
5656
async fn encode_retry() -> http_types::Result<()> {
57-
let buf = Cursor::new(vec![]);
58-
let mut encoder = encode(buf);
59-
encoder.send_retry(Duration::from_secs(12), None).await?;
60-
let mut buf = encoder.into_writer();
61-
buf.set_position(0);
57+
let (sender, encoder) = encode();
58+
task::spawn(async move {
59+
let dur = Duration::from_secs(12);
60+
sender.send_retry(dur, None).await;
61+
});
6262

63-
let mut reader = decode(buf);
63+
let mut reader = decode(BufReader::new(encoder));
6464
let event = reader.next().await.unwrap()?;
6565
assert_retry(&event, 12);
6666
Ok(())

0 commit comments

Comments
 (0)