Skip to content

Commit a39b30f

Browse files
committed
Use state machine for server encoder as well
1 parent daaa9cc commit a39b30f

File tree

3 files changed

+72
-41
lines changed

3 files changed

+72
-41
lines changed

src/chunked/encoder.rs

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@ const CRLF_LEN: usize = 2;
1212
/// The encoder state.
1313
#[derive(Debug)]
1414
enum State {
15-
/// Starting state.
16-
Start,
15+
/// Initing state.
16+
Init,
1717
/// Streaming out chunks.
1818
EncodeChunks,
1919
/// No more chunks to stream, mark the end.
@@ -41,7 +41,7 @@ impl ChunkedEncoder {
4141
/// Create a new instance.
4242
pub(crate) fn new() -> Self {
4343
Self {
44-
state: State::Start,
44+
state: State::Init,
4545
bytes_written: 0,
4646
}
4747
}
@@ -72,15 +72,18 @@ impl ChunkedEncoder {
7272
buf: &mut [u8],
7373
) -> Poll<io::Result<usize>> {
7474
self.bytes_written = 0;
75-
match self.state {
76-
State::Start => self.init(res, cx, buf),
75+
let res = match self.state {
76+
State::Init => self.init(res, cx, buf),
7777
State::EncodeChunks => self.encode_chunks(res, cx, buf),
7878
State::EndOfChunks => self.encode_chunks_eos(res, cx, buf),
7979
State::ReceiveTrailers => self.encode_trailers(res, cx, buf),
8080
State::EncodeTrailers => self.encode_trailers(res, cx, buf),
8181
State::EndOfStream => self.encode_eos(cx, buf),
82-
State::End => Poll::Ready(Ok(0)),
83-
}
82+
State::End => Poll::Ready(Ok(self.bytes_written)),
83+
};
84+
85+
log::trace!("ChunkedEncoder {} bytes written", self.bytes_written);
86+
res
8487
}
8588

8689
/// Switch the internal state to a new state.
@@ -90,7 +93,7 @@ impl ChunkedEncoder {
9093

9194
#[cfg(debug_assertions)]
9295
match self.state {
93-
Start => assert!(matches!(state, EncodeChunks)),
96+
Init => assert!(matches!(state, EncodeChunks)),
9497
EncodeChunks => assert!(matches!(state, EndOfChunks)),
9598
EndOfChunks => assert!(matches!(state, ReceiveTrailers)),
9699
ReceiveTrailers => assert!(matches!(state, EncodeTrailers | EndOfStream)),
@@ -180,7 +183,6 @@ impl ChunkedEncoder {
180183
self.bytes_written += CRLF_LEN;
181184

182185
// Finally return how many bytes we've written to the buffer.
183-
log::trace!("sending {} bytes", self.bytes_written);
184186
Poll::Ready(Ok(self.bytes_written))
185187
}
186188

src/server/encode.rs

Lines changed: 60 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ pub(crate) struct Encoder {
1818
/// HTTP headers to be sent.
1919
res: Response,
2020
/// The state of the encoding process
21-
state: EncoderState,
21+
state: State,
2222
/// Track bytes read in a call to poll_read.
2323
bytes_read: usize,
2424
/// The data we're writing as part of the head section.
@@ -36,20 +36,21 @@ pub(crate) struct Encoder {
3636
}
3737

3838
#[derive(Debug)]
39-
enum EncoderState {
40-
Start,
41-
Head,
39+
enum State {
40+
Init,
41+
ComputeHead,
42+
EncodeHead,
4243
FixedBody,
4344
ChunkedBody,
44-
Done,
45+
End,
4546
}
4647

4748
impl Encoder {
48-
/// Create a new instance.
49-
pub(crate) fn encode(res: Response) -> Self {
49+
/// Create a new instance of Encoder.
50+
pub(crate) fn new(res: Response) -> Self {
5051
Self {
5152
res,
52-
state: EncoderState::Start,
53+
state: State::Init,
5354
bytes_read: 0,
5455
head: vec![],
5556
head_bytes_read: 0,
@@ -58,14 +59,51 @@ impl Encoder {
5859
chunked: ChunkedEncoder::new(),
5960
}
6061
}
61-
}
6262

63-
impl Encoder {
64-
// Encode the headers to a buffer, the first time we poll.
65-
fn encode_start(&mut self, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
66-
log::trace!("Server response encoding: start");
67-
self.state = EncoderState::Head;
63+
pub(crate) fn encode(
64+
&mut self,
65+
cx: &mut Context<'_>,
66+
buf: &mut [u8],
67+
) -> Poll<io::Result<usize>> {
68+
self.bytes_read = 0;
69+
let res = match self.state {
70+
State::Init => self.init(cx, buf),
71+
State::ComputeHead => self.compute_head(cx, buf),
72+
State::EncodeHead => self.encode_head(cx, buf),
73+
State::FixedBody => self.encode_fixed_body(cx, buf),
74+
State::ChunkedBody => self.encode_chunked_body(cx, buf),
75+
State::End => Poll::Ready(Ok(self.bytes_read)),
76+
};
77+
log::trace!("ServerEncoder {} bytes written", self.bytes_read);
78+
res
79+
}
80+
81+
/// Switch the internal state to a new state.
82+
fn set_state(&mut self, state: State) {
83+
use State::*;
84+
log::trace!("Server Encoder state: {:?} -> {:?}", self.state, state);
6885

86+
#[cfg(debug_assertions)]
87+
match self.state {
88+
Init => assert!(matches!(state, ComputeHead)),
89+
ComputeHead => assert!(matches!(state, EncodeHead)),
90+
EncodeHead => assert!(matches!(state, ChunkedBody | FixedBody)),
91+
FixedBody => assert!(matches!(state, End)),
92+
ChunkedBody => assert!(matches!(state, End)),
93+
End => panic!("No state transitions allowed after the stream has ended"),
94+
}
95+
96+
self.state = state;
97+
}
98+
99+
/// Initialize to the first state.
100+
fn init(&mut self, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
101+
self.set_state(State::ComputeHead);
102+
self.compute_head(cx, buf)
103+
}
104+
105+
/// Encode the headers to a buffer, the first time we poll.
106+
fn compute_head(&mut self, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
69107
let reason = self.res.status().canonical_reason();
70108
let status = self.res.status();
71109
std::io::Write::write_fmt(
@@ -97,6 +135,8 @@ impl Encoder {
97135
}
98136

99137
std::io::Write::write_fmt(&mut self.head, format_args!("\r\n"))?;
138+
139+
self.set_state(State::EncodeHead);
100140
self.encode_head(cx, buf)
101141
}
102142

@@ -118,12 +158,12 @@ impl Encoder {
118158
match self.res.len() {
119159
Some(body_len) => {
120160
self.body_len = body_len;
121-
self.state = EncoderState::FixedBody;
161+
self.state = State::FixedBody;
122162
log::trace!("Server response encoding: fixed length body");
123163
return self.encode_fixed_body(cx, buf);
124164
}
125165
None => {
126-
self.state = EncoderState::ChunkedBody;
166+
self.state = State::ChunkedBody;
127167
log::trace!("Server response encoding: chunked body");
128168
return self.encode_chunked_body(cx, buf);
129169
}
@@ -177,12 +217,12 @@ impl Encoder {
177217

178218
if self.body_len == self.body_bytes_read {
179219
// If we've read the `len` number of bytes, end
180-
self.state = EncoderState::Done;
220+
self.set_state(State::End);
181221
return Poll::Ready(Ok(self.bytes_read));
182222
} else if new_body_bytes_read == 0 {
183223
// If we've reached unexpected EOF, end anyway
184224
// TODO: do something?
185-
self.state = EncoderState::Done;
225+
self.set_state(State::End);
186226
return Poll::Ready(Ok(self.bytes_read));
187227
} else {
188228
self.encode_fixed_body(cx, buf)
@@ -201,7 +241,7 @@ impl Encoder {
201241
Poll::Ready(Ok(read)) => {
202242
self.bytes_read += read;
203243
if self.bytes_read == 0 {
204-
self.state = EncoderState::Done
244+
self.set_state(State::End);
205245
}
206246
Poll::Ready(Ok(self.bytes_read))
207247
}
@@ -222,17 +262,6 @@ impl Read for Encoder {
222262
cx: &mut Context<'_>,
223263
buf: &mut [u8],
224264
) -> Poll<io::Result<usize>> {
225-
// we keep track how many bytes of the head and body we've read
226-
// in this call of `poll_read`
227-
self.bytes_read = 0;
228-
let res = match self.state {
229-
EncoderState::Start => self.encode_start(cx, buf),
230-
EncoderState::Head => self.encode_head(cx, buf),
231-
EncoderState::FixedBody => self.encode_fixed_body(cx, buf),
232-
EncoderState::ChunkedBody => self.encode_chunked_body(cx, buf),
233-
EncoderState::Done => Poll::Ready(Ok(0)),
234-
};
235-
// dbg!(String::from_utf8(buf[..self.bytes_read].to_vec()).unwrap());
236-
res
265+
self.encode(cx, buf)
237266
}
238267
}

src/server/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ where
7373

7474
// Pass the request to the endpoint and encode the response.
7575
let res = endpoint(req).await?;
76-
let mut encoder = Encoder::encode(res);
76+
let mut encoder = Encoder::new(res);
7777

7878
// Stream the response to the writer.
7979
io::copy(&mut encoder, &mut io).await?;

0 commit comments

Comments
 (0)