Skip to content

Commit 6b4c34a

Browse files
committed
Simplify state machine internals more
1 parent 4cf4c45 commit 6b4c34a

File tree

2 files changed

+91
-95
lines changed

2 files changed

+91
-95
lines changed

src/chunked/encoder.rs

Lines changed: 41 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ const CRLF_LEN: usize = 2;
1313
#[derive(Debug)]
1414
enum State {
1515
/// Starting state.
16-
Init,
16+
Start,
1717
/// Streaming out chunks.
1818
EncodeChunks,
1919
/// No more chunks to stream, mark the end.
@@ -41,7 +41,7 @@ impl ChunkedEncoder {
4141
/// Create a new instance.
4242
pub(crate) fn new() -> Self {
4343
Self {
44-
state: State::Init,
44+
state: State::Start,
4545
bytes_written: 0,
4646
}
4747
}
@@ -72,48 +72,53 @@ impl ChunkedEncoder {
7272
buf: &mut [u8],
7373
) -> Poll<io::Result<usize>> {
7474
self.bytes_written = 0;
75-
let res = match self.state {
76-
State::Init => self.init(res, cx, buf),
75+
let res = self.exec(res, cx, buf);
76+
log::trace!("ChunkedEncoder {} bytes written", self.bytes_written);
77+
res
78+
}
79+
80+
/// Execute the right method for the current state.
81+
fn exec(
82+
&mut self,
83+
res: &mut Response,
84+
cx: &mut Context<'_>,
85+
buf: &mut [u8],
86+
) -> Poll<io::Result<usize>> {
87+
match self.state {
88+
State::Start => self.set_state(State::EncodeChunks, res, cx, buf),
7789
State::EncodeChunks => self.encode_chunks(res, cx, buf),
7890
State::EndOfChunks => self.encode_chunks_eos(res, cx, buf),
79-
State::ReceiveTrailers => self.encode_trailers(res, cx, buf),
91+
State::ReceiveTrailers => self.receive_trailers(res, cx, buf),
8092
State::EncodeTrailers => self.encode_trailers(res, cx, buf),
81-
State::EndOfStream => self.encode_eos(cx, buf),
93+
State::EndOfStream => self.encode_eos(res, cx, buf),
8294
State::End => Poll::Ready(Ok(self.bytes_written)),
83-
};
84-
85-
log::trace!("ChunkedEncoder {} bytes written", self.bytes_written);
86-
res
95+
}
8796
}
8897

8998
/// Switch the internal state to a new state.
90-
fn set_state(&mut self, state: State) {
99+
fn set_state(
100+
&mut self,
101+
state: State,
102+
res: &mut Response,
103+
cx: &mut Context<'_>,
104+
buf: &mut [u8],
105+
) -> Poll<io::Result<usize>> {
91106
use State::*;
92107
log::trace!("ChunkedEncoder state: {:?} -> {:?}", self.state, state);
93108

94109
#[cfg(debug_assertions)]
95110
match self.state {
96-
Init => assert!(matches!(state, EncodeChunks)),
111+
Start => assert!(matches!(state, EncodeChunks)),
97112
EncodeChunks => assert!(matches!(state, EndOfChunks)),
98113
EndOfChunks => assert!(matches!(state, ReceiveTrailers)),
99114
ReceiveTrailers => assert!(matches!(state, EncodeTrailers | EndOfStream)),
100115
EncodeTrailers => assert!(matches!(state, EndOfStream)),
101116
EndOfStream => assert!(matches!(state, End)),
102-
End => panic!("No state transitions allowed after the stream has ended"),
117+
End => panic!("No state transitions allowed after the ChunkedEncoder has ended"),
103118
}
104119

105120
self.state = state;
106-
}
107-
108-
/// Init encoding.
109-
fn init(
110-
&mut self,
111-
res: &mut Response,
112-
cx: &mut Context<'_>,
113-
buf: &mut [u8],
114-
) -> Poll<io::Result<usize>> {
115-
self.set_state(State::EncodeChunks);
116-
self.encode_chunks(res, cx, buf)
121+
self.exec(res, cx, buf)
117122
}
118123

119124
/// Stream out data using chunked encoding.
@@ -137,8 +142,7 @@ impl ChunkedEncoder {
137142
// If the stream doesn't have any more bytes left to read we're done
138143
// sending chunks and it's time to move on.
139144
if src.len() == 0 {
140-
self.set_state(State::EndOfChunks);
141-
return self.encode_chunks_eos(res, cx, buf);
145+
return self.set_state(State::EndOfChunks, res, cx, buf);
142146
}
143147

144148
// Each chunk is prefixed with the length of the data in hex, then a
@@ -205,8 +209,7 @@ impl ChunkedEncoder {
205209
buf[idx + 2] = LF;
206210
self.bytes_written += 1 + CRLF_LEN;
207211

208-
self.set_state(State::ReceiveTrailers);
209-
return self.receive_trailers(res, cx, buf);
212+
self.set_state(State::ReceiveTrailers, res, cx, buf)
210213
}
211214

212215
/// Receive trailers sent to the response, and store them in an internal
@@ -218,31 +221,32 @@ impl ChunkedEncoder {
218221
buf: &mut [u8],
219222
) -> Poll<io::Result<usize>> {
220223
// TODO: actually wait for trailers to be received.
221-
self.set_state(State::EncodeTrailers);
222-
self.encode_trailers(res, cx, buf)
224+
self.set_state(State::EncodeTrailers, res, cx, buf)
223225
}
224226

225227
/// Send trailers to the buffer.
226228
fn encode_trailers(
227229
&mut self,
228-
_res: &mut Response,
230+
res: &mut Response,
229231
cx: &mut Context<'_>,
230232
buf: &mut [u8],
231233
) -> Poll<io::Result<usize>> {
232234
// TODO: actually encode trailers here.
233-
self.set_state(State::EndOfStream);
234-
self.encode_eos(cx, buf)
235+
self.set_state(State::EndOfStream, res, cx, buf)
235236
}
236237

237238
/// Encode the end of the stream.
238-
fn encode_eos(&mut self, _cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
239+
fn encode_eos(
240+
&mut self,
241+
res: &mut Response,
242+
cx: &mut Context<'_>,
243+
buf: &mut [u8],
244+
) -> Poll<io::Result<usize>> {
239245
let idx = self.bytes_written;
240246
// Write the final CRLF
241247
buf[idx] = CR;
242248
buf[idx + 1] = LF;
243249
self.bytes_written += CRLF_LEN;
244-
245-
self.set_state(State::End);
246-
return Poll::Ready(Ok(self.bytes_written));
250+
self.set_state(State::End, res, cx, buf)
247251
}
248252
}

src/server/encode.rs

Lines changed: 50 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ use crate::date::fmt_http_date;
1515
/// This is returned from [`encode`].
1616
#[derive(Debug)]
1717
pub(crate) struct Encoder {
18+
/// The current level of recursion the encoder is in.
19+
depth: u16,
1820
/// HTTP headers to be sent.
1921
res: Response,
2022
/// The state of the encoding process
@@ -38,7 +40,7 @@ pub(crate) struct Encoder {
3840
#[derive(Debug)]
3941
enum State {
4042
/// Starting state.
41-
Init,
43+
Start,
4244
/// Write the HEAD section to an intermediate buffer.
4345
ComputeHead,
4446
/// Stream out the HEAD section.
@@ -51,12 +53,26 @@ enum State {
5153
End,
5254
}
5355

56+
impl Read for Encoder {
57+
fn poll_read(
58+
mut self: Pin<&mut Self>,
59+
cx: &mut Context<'_>,
60+
buf: &mut [u8],
61+
) -> Poll<io::Result<usize>> {
62+
self.bytes_written = 0;
63+
let res = self.exec(cx, buf);
64+
log::trace!("ServerEncoder {} bytes written", self.bytes_written);
65+
res
66+
}
67+
}
68+
5469
impl Encoder {
5570
/// Create a new instance of Encoder.
5671
pub(crate) fn new(res: Response) -> Self {
5772
Self {
5873
res,
59-
state: State::Init,
74+
depth: 0,
75+
state: State::Start,
6076
bytes_written: 0,
6177
head: vec![],
6278
head_bytes_written: 0,
@@ -66,46 +82,40 @@ impl Encoder {
6682
}
6783
}
6884

69-
pub(crate) fn encode(
85+
/// Switch the internal state to a new state.
86+
fn set_state(
7087
&mut self,
88+
state: State,
7189
cx: &mut Context<'_>,
7290
buf: &mut [u8],
7391
) -> Poll<io::Result<usize>> {
74-
self.bytes_written = 0;
75-
let res = match self.state {
76-
State::Init => self.init(cx, buf),
77-
State::ComputeHead => self.compute_head(cx, buf),
78-
State::EncodeHead => self.encode_head(cx, buf),
79-
State::EncodeFixedBody => self.encode_fixed_body(cx, buf),
80-
State::EncodeChunkedBody => self.encode_chunked_body(cx, buf),
81-
State::End => Poll::Ready(Ok(self.bytes_written)),
82-
};
83-
log::trace!("ServerEncoder {} bytes written", self.bytes_written);
84-
res
85-
}
86-
87-
/// Switch the internal state to a new state.
88-
fn set_state(&mut self, state: State) {
8992
use State::*;
90-
log::trace!("Server Encoder state: {:?} -> {:?}", self.state, state);
93+
log::trace!("ServerEncoder state: {:?} -> {:?}", self.state, state);
9194

9295
#[cfg(debug_assertions)]
9396
match self.state {
94-
Init => assert!(matches!(state, ComputeHead)),
97+
Start => assert!(matches!(state, ComputeHead)),
9598
ComputeHead => assert!(matches!(state, EncodeHead)),
9699
EncodeHead => assert!(matches!(state, EncodeChunkedBody | EncodeFixedBody)),
97100
EncodeFixedBody => assert!(matches!(state, End)),
98101
EncodeChunkedBody => assert!(matches!(state, End)),
99-
End => panic!("No state transitions allowed after the stream has ended"),
102+
End => panic!("No state transitions allowed after the ServerEncoder has ended"),
100103
}
101104

102105
self.state = state;
106+
self.exec(cx, buf)
103107
}
104108

105-
/// Initialize to the first state.
106-
fn init(&mut self, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
107-
self.set_state(State::ComputeHead);
108-
self.compute_head(cx, buf)
109+
/// Execute the right method for the current state.
110+
fn exec(&mut self, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
111+
match self.state {
112+
State::Start => self.set_state(State::ComputeHead, cx, buf),
113+
State::ComputeHead => self.compute_head(cx, buf),
114+
State::EncodeHead => self.encode_head(cx, buf),
115+
State::EncodeFixedBody => self.encode_fixed_body(cx, buf),
116+
State::EncodeChunkedBody => self.encode_chunked_body(cx, buf),
117+
State::End => Poll::Ready(Ok(self.bytes_written)),
118+
}
109119
}
110120

111121
/// Encode the headers to a buffer, the first time we poll.
@@ -142,8 +152,7 @@ impl Encoder {
142152

143153
std::io::Write::write_fmt(&mut self.head, format_args!("\r\n"))?;
144154

145-
self.set_state(State::EncodeHead);
146-
self.encode_head(cx, buf)
155+
self.set_state(State::EncodeHead, cx, buf)
147156
}
148157

149158
/// Encode the status code + headers.
@@ -164,18 +173,14 @@ impl Encoder {
164173
match self.res.len() {
165174
Some(body_len) => {
166175
self.body_len = body_len;
167-
self.state = State::EncodeFixedBody;
168-
return self.encode_fixed_body(cx, buf);
176+
self.set_state(State::EncodeFixedBody, cx, buf)
169177
}
170-
None => {
171-
self.state = State::EncodeChunkedBody;
172-
return self.encode_chunked_body(cx, buf);
173-
}
174-
};
178+
None => self.set_state(State::EncodeChunkedBody, cx, buf),
179+
}
175180
} else {
176181
// If we haven't read the entire header it means `buf` isn't
177182
// big enough. Break out of loop and return from `poll_read`
178-
return Poll::Ready(Ok(self.bytes_written));
183+
Poll::Ready(Ok(self.bytes_written))
179184
}
180185
}
181186

@@ -222,14 +227,13 @@ impl Encoder {
222227

223228
if self.body_len == self.body_bytes_written {
224229
// If we've read the `len` number of bytes, end
225-
self.set_state(State::End);
226-
return Poll::Ready(Ok(self.bytes_written));
230+
self.set_state(State::End, cx, buf)
227231
} else if new_body_bytes_written == 0 {
228232
// If we've reached unexpected EOF, end anyway
229233
// TODO: do something?
230-
self.set_state(State::End);
231-
return Poll::Ready(Ok(self.bytes_written));
234+
self.set_state(State::End, cx, buf)
232235
} else {
236+
// Else continue encoding
233237
self.encode_fixed_body(cx, buf)
234238
}
235239
}
@@ -245,28 +249,16 @@ impl Encoder {
245249
match self.chunked.encode(&mut self.res, cx, buf) {
246250
Poll::Ready(Ok(read)) => {
247251
self.bytes_written += read;
248-
if self.bytes_written == 0 {
249-
self.set_state(State::End);
252+
match self.bytes_written {
253+
0 => self.set_state(State::End, cx, buf),
254+
_ => Poll::Ready(Ok(self.bytes_written)),
250255
}
251-
Poll::Ready(Ok(self.bytes_written))
252256
}
253257
Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
254-
Poll::Pending => {
255-
if self.bytes_written > 0 {
256-
return Poll::Ready(Ok(self.bytes_written));
257-
}
258-
Poll::Pending
259-
}
258+
Poll::Pending => match self.bytes_written {
259+
0 => Poll::Pending,
260+
_ => Poll::Ready(Ok(self.bytes_written)),
261+
},
260262
}
261263
}
262264
}
263-
264-
impl Read for Encoder {
265-
fn poll_read(
266-
mut self: Pin<&mut Self>,
267-
cx: &mut Context<'_>,
268-
buf: &mut [u8],
269-
) -> Poll<io::Result<usize>> {
270-
self.encode(cx, buf)
271-
}
272-
}

0 commit comments

Comments
 (0)