Skip to content

Commit a94097c

Browse files
committed
Move to state machine
1 parent e1c1b24 commit a94097c

File tree

1 file changed

+145
-94
lines changed

1 file changed

+145
-94
lines changed

src/server.rs

Lines changed: 145 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -84,34 +84,37 @@ where
8484
/// This is returned from [`encode`].
8585
#[derive(Debug)]
8686
struct Encoder {
87-
/// Keep track how far we've indexed into the headers + body.
88-
cursor: usize,
8987
/// HTTP headers to be sent.
90-
head: Option<Vec<u8>>,
91-
/// Check whether we're done sending headers.
92-
head_done: bool,
93-
/// Response containing the HTTP body to be sent.
9488
res: Response,
95-
/// Check whether we're done with the body.
96-
body_done: bool,
97-
/// Keep track of how many bytes have been read from the body stream.
98-
body_bytes_read: usize,
89+
/// The state of the encoding process
90+
state: EncoderState,
91+
}
92+
93+
/// Tracks how far encoding of the response has progressed.
///
/// The encoder moves linearly through these states:
/// `Start -> Head -> (Body | Chunked) -> Done`.
#[derive(Debug)]
enum EncoderState {
    /// Nothing has been encoded yet; the head still needs serializing.
    Start,
    /// Streaming out the serialized status line and headers.
    Head {
        /// The serialized head bytes.
        data: Vec<u8>,
        /// How many bytes of `data` have been copied out so far.
        head_bytes_read: usize,
    },
    /// Streaming out a body whose total length is known up front.
    Body {
        /// How many body bytes have been copied out so far.
        body_bytes_read: usize,
        /// Total length of the body in bytes.
        body_len: usize,
    },
    /// Streaming out a body of unknown length using chunked transfer encoding.
    Chunked,
    /// Everything has been written; subsequent reads return 0 bytes.
    Done,
}
100107

101108
impl Encoder {
102109
/// Create a new instance.
103110
pub(crate) fn encode(res: Response) -> Self {
104111
Self {
105112
res,
106-
head: None,
107-
cursor: 0,
108-
head_done: false,
109-
body_done: false,
110-
body_bytes_read: 0,
113+
state: EncoderState::Start,
111114
}
112115
}
113116

114-
fn encode_head(&mut self) -> io::Result<()> {
117+
fn encode_head(&self) -> io::Result<Vec<u8>> {
115118
let mut head: Vec<u8> = vec![];
116119

117120
let reason = self.res.status().canonical_reason();
@@ -138,8 +141,7 @@ impl Encoder {
138141

139142
std::io::Write::write_fmt(&mut head, format_args!("\r\n"))?;
140143

141-
self.head = Some(head);
142-
Ok(())
144+
Ok(head)
143145
}
144146
}
145147

@@ -149,105 +151,154 @@ impl Read for Encoder {
149151
cx: &mut Context<'_>,
150152
buf: &mut [u8],
151153
) -> Poll<io::Result<usize>> {
152-
// Encode the headers to a buffer, the first time we poll
153-
if let None = self.head {
154-
self.encode_head()?;
155-
}
156-
157154
// we must keep track how many bytes of the head and body we've read
158155
// in this call of `poll_read`
159-
let mut head_bytes_read = 0;
160-
let mut body_bytes_read = 0;
161-
162-
// Read from the serialized headers, url and methods.
163-
if !self.head_done {
164-
let head = self.head.as_ref().unwrap();
165-
let head_len = head.len();
166-
let len = std::cmp::min(head.len() - self.cursor, buf.len());
167-
let range = self.cursor..self.cursor + len;
168-
buf[0..len].copy_from_slice(&head[range]);
169-
self.cursor += len;
170-
if self.cursor == head_len {
171-
self.head_done = true;
172-
}
173-
head_bytes_read += len;
174-
}
175-
176-
// Read from the AsyncRead impl on the inner Response struct only if
177-
// done reading from the head.
178-
// We must ensure there's space to write at least 2 bytes into the
179-
// response stream.
180-
if self.head_done && !self.body_done && head_bytes_read <= buf.len() - 2 {
181-
// figure out how many bytes we can read. If a len was set, we need
182-
// to make sure we don't read more than that.
183-
let upper_bound = match self.res.len() {
184-
Some(len) => {
185-
debug_assert!(head_bytes_read == 0 || self.body_bytes_read == 0);
186-
(head_bytes_read + len - self.body_bytes_read).min(buf.len())
156+
let mut bytes_read = 0;
157+
loop {
158+
println!("{:?}", self.state);
159+
match self.state {
160+
EncoderState::Start => {
161+
// Encode the headers to a buffer, the first time we poll
162+
let head = self.encode_head()?;
163+
self.state = EncoderState::Head {
164+
data: head,
165+
head_bytes_read: 0,
166+
};
187167
}
188-
None => buf.len() - 2,
189-
};
190-
191-
match self.res.len() {
192-
Some(len) => {
168+
EncoderState::Head {
169+
ref data,
170+
mut head_bytes_read,
171+
} => {
172+
// Read from the serialized headers, url and methods.
173+
let head_len = data.len();
174+
let len = std::cmp::min(head_len - head_bytes_read, buf.len());
175+
let range = head_bytes_read..head_bytes_read + len;
176+
buf[0..len].copy_from_slice(&data[range]);
177+
bytes_read += len;
178+
head_bytes_read += len;
179+
180+
// If we've read the total length of the head we're done
181+
// reading the head and can transition to reading the body
182+
if head_bytes_read == head_len {
183+
// The response length lets us know if we are encoding
184+
// our body in chunks or not
185+
self.state = match self.res.len() {
186+
Some(body_len) => EncoderState::Body {
187+
body_bytes_read: 0,
188+
body_len,
189+
},
190+
None => EncoderState::Chunked,
191+
};
192+
}
193+
}
194+
EncoderState::Body {
195+
mut body_bytes_read,
196+
body_len,
197+
} => {
198+
// Double check that we didn't somehow read more bytes than
199+
// can fit in our buffer
200+
debug_assert!(bytes_read <= buf.len());
201+
202+
// ensure we have at least room for 1 more byte in our buffer
203+
if bytes_read == buf.len() {
204+
break;
205+
}
206+
// figure out how many bytes we can read
207+
let upper_bound = (bytes_read + body_len - body_bytes_read).min(buf.len());
193208
// Read bytes, and update internal tracking stuff.
194-
body_bytes_read = ready!(Pin::new(&mut self.res)
195-
.poll_read(cx, &mut buf[head_bytes_read..upper_bound]))?;
196-
209+
let new_body_bytes_read =
210+
ready!(Pin::new(&mut self.res)
211+
.poll_read(cx, &mut buf[bytes_read..upper_bound]))?;
212+
body_bytes_read += new_body_bytes_read;
213+
bytes_read += new_body_bytes_read;
214+
215+
// Double check we did not read more body bytes than the total
216+
// length of the body
197217
debug_assert!(
198-
self.body_bytes_read <= len,
218+
body_bytes_read <= body_len,
199219
"Too many bytes read. Expected: {}, read: {}",
200-
len,
201-
self.body_bytes_read
220+
body_len,
221+
body_bytes_read
202222
);
203223
// If we've read the `len` number of bytes or the stream no longer gives bytes, end.
204-
if len == self.body_bytes_read || body_bytes_read == 0 {
205-
self.body_done = true;
206-
}
224+
self.state = if body_len == body_bytes_read || body_bytes_read == 0 {
225+
EncoderState::Done
226+
} else {
227+
EncoderState::Body {
228+
body_bytes_read,
229+
body_len,
230+
}
231+
};
207232
}
208-
None => {
209-
let mut chunk_buf = vec![0; buf.len()];
233+
EncoderState::Chunked => {
234+
// ensure we have at least room for 1 more byte in our buffer
235+
if bytes_read == buf.len() {
236+
break;
237+
}
238+
239+
// We can read a maximum of the buffer's total size
240+
// minus what we've already filled the buffer with
241+
let buffer_remaining = buf.len() - bytes_read;
242+
// we must allocate a separate buffer for the chunk data
243+
// since we first need to know its length before writing
244+
// it into the actual buffer
245+
let mut chunk_buf = vec![0; buffer_remaining];
210246
// Read bytes from body reader
211-
let chunk_length = ready!(Pin::new(&mut self.res)
212-
.poll_read(cx, &mut chunk_buf[0..buf.len() - head_bytes_read]))?;
247+
let chunk_length =
248+
ready!(Pin::new(&mut self.res).poll_read(cx, &mut chunk_buf))?;
213249

214250
// serialize chunk length as hex
215251
let chunk_length_string = format!("{:X}", chunk_length);
216252
let chunk_length_bytes = chunk_length_string.as_bytes();
217253
let chunk_length_bytes_len = chunk_length_bytes.len();
218-
body_bytes_read += chunk_length_bytes_len;
219-
buf[head_bytes_read..head_bytes_read + body_bytes_read]
220-
.copy_from_slice(chunk_length_bytes);
221-
222-
// follow chunk length with CRLF
223-
buf[head_bytes_read + body_bytes_read] = b'\r';
224-
buf[head_bytes_read + body_bytes_read + 1] = b'\n';
225-
body_bytes_read += 2;
226-
227-
// copy chunk into buf
228-
buf[head_bytes_read + body_bytes_read
229-
..head_bytes_read + body_bytes_read + chunk_length]
230-
.copy_from_slice(&chunk_buf[..chunk_length]);
231-
body_bytes_read += chunk_length;
232-
233-
// TODO: relax this constraint at some point by adding extra state
234-
let bytes_read = head_bytes_read + body_bytes_read;
235-
assert!(buf.len() >= bytes_read + 7, "Buffers should have room for the head, the chunk length, 2 bytes, the chunk, and 4 extra bytes when using chunked encoding");
254+
const CRLF_LENGTH: usize = 2;
255+
256+
// calculate the total size of the chunk including serialized
257+
// length and the CRLF padding
258+
let total_chunk_size = bytes_read
259+
+ chunk_length_bytes_len
260+
+ CRLF_LENGTH
261+
+ chunk_length
262+
+ CRLF_LENGTH;
263+
264+
// See if we can write the chunk out in one go
265+
if total_chunk_size < buffer_remaining {
266+
// Write the chunk length into the buffer
267+
buf[bytes_read..bytes_read + chunk_length_bytes_len]
268+
.copy_from_slice(chunk_length_bytes);
269+
bytes_read += chunk_length_bytes_len;
270+
271+
// follow chunk length with CRLF
272+
buf[bytes_read] = b'\r';
273+
buf[bytes_read + 1] = b'\n';
274+
bytes_read += 2;
275+
276+
// copy chunk into buf
277+
buf[bytes_read..bytes_read + chunk_length]
278+
.copy_from_slice(&chunk_buf[..chunk_length]);
279+
bytes_read += chunk_length;
280+
281+
// follow chunk with CRLF
282+
buf[bytes_read] = b'\r';
283+
buf[bytes_read + 1] = b'\n';
284+
bytes_read += 2;
285+
} else {
286+
unimplemented!("TODO: handle when buf isn't big enough");
287+
}
288+
const EMPTY_CHUNK: &[u8; 5] = b"0\r\n\r\n";
236289

237-
buf[bytes_read..bytes_read + 7].copy_from_slice(b"\r\n0\r\n\r\n");
290+
buf[bytes_read..bytes_read + EMPTY_CHUNK.len()].copy_from_slice(EMPTY_CHUNK);
238291

239292
// if body_bytes_read == 0 {
240-
self.body_done = true;
241-
body_bytes_read += 7;
293+
bytes_read += 7;
294+
self.state = EncoderState::Done;
242295
// }
243296
}
297+
EncoderState::Done => break,
244298
}
245299
}
300+
println!("{:?}", std::str::from_utf8(&buf[0..bytes_read]));
246301

247-
self.body_bytes_read += body_bytes_read; // total body bytes read on all polls
248-
249-
// Return the total amount of bytes read.
250-
let bytes_read = head_bytes_read + body_bytes_read;
251302
Poll::Ready(Ok(bytes_read as usize))
252303
}
253304
}

0 commit comments

Comments
 (0)