|
1 |
| -use futures::io::{AsyncRead, AsyncWrite}; |
2 |
| -use rustls::Session; |
3 |
| -use std::io::{self, Read, Write}; |
4 |
| -use std::marker::Unpin; |
5 |
| -use std::pin::Pin; |
6 |
| -use std::task::{Context, Poll}; |
7 | 1 | pub(crate) mod tls_state;
|
8 |
| - |
9 |
| -pub struct Stream<'a, IO, S> { |
10 |
| - pub io: &'a mut IO, |
11 |
| - pub session: &'a mut S, |
12 |
| - pub eof: bool, |
13 |
| -} |
14 |
| - |
15 |
| -trait WriteTls<IO: AsyncWrite, S: Session> { |
16 |
| - fn write_tls(&mut self, cx: &mut Context) -> io::Result<usize>; |
17 |
| -} |
18 |
| - |
19 |
| -#[derive(Clone, Copy)] |
20 |
| -enum Focus { |
21 |
| - Empty, |
22 |
| - Readable, |
23 |
| - Writable, |
24 |
| -} |
25 |
| - |
26 |
| -impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> { |
27 |
| - pub fn new(io: &'a mut IO, session: &'a mut S) -> Self { |
28 |
| - Stream { |
29 |
| - io, |
30 |
| - session, |
31 |
| - // The state so far is only used to detect EOF, so either Stream |
32 |
| - // or EarlyData state should both be all right. |
33 |
| - eof: false, |
34 |
| - } |
35 |
| - } |
36 |
| - |
37 |
| - pub fn set_eof(mut self, eof: bool) -> Self { |
38 |
| - self.eof = eof; |
39 |
| - self |
40 |
| - } |
41 |
| - |
42 |
| - pub fn as_mut_pin(&mut self) -> Pin<&mut Self> { |
43 |
| - Pin::new(self) |
44 |
| - } |
45 |
| - |
46 |
| - pub fn complete_io(&mut self, cx: &mut Context) -> Poll<io::Result<(usize, usize)>> { |
47 |
| - self.complete_inner_io(cx, Focus::Empty) |
48 |
| - } |
49 |
| - |
50 |
| - fn complete_read_io(&mut self, cx: &mut Context) -> Poll<io::Result<usize>> { |
51 |
| - struct Reader<'a, 'b, T> { |
52 |
| - io: &'a mut T, |
53 |
| - cx: &'a mut Context<'b>, |
54 |
| - } |
55 |
| - |
56 |
| - impl<'a, 'b, T: AsyncRead + Unpin> Read for Reader<'a, 'b, T> { |
57 |
| - fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { |
58 |
| - match Pin::new(&mut self.io).poll_read(self.cx, buf) { |
59 |
| - Poll::Ready(result) => result, |
60 |
| - Poll::Pending => Err(io::ErrorKind::WouldBlock.into()), |
61 |
| - } |
62 |
| - } |
63 |
| - } |
64 |
| - |
65 |
| - let mut reader = Reader { io: self.io, cx }; |
66 |
| - |
67 |
| - let n = match self.session.read_tls(&mut reader) { |
68 |
| - Ok(n) => n, |
69 |
| - Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => return Poll::Pending, |
70 |
| - Err(err) => return Poll::Ready(Err(err)), |
71 |
| - }; |
72 |
| - |
73 |
| - self.session.process_new_packets().map_err(|err| { |
74 |
| - // In case we have an alert to send describing this error, |
75 |
| - // try a last-gasp write -- but don't predate the primary |
76 |
| - // error. |
77 |
| - let _ = self.write_tls(cx); |
78 |
| - |
79 |
| - io::Error::new(io::ErrorKind::InvalidData, err) |
80 |
| - })?; |
81 |
| - |
82 |
| - Poll::Ready(Ok(n)) |
83 |
| - } |
84 |
| - |
85 |
| - fn complete_write_io(&mut self, cx: &mut Context) -> Poll<io::Result<usize>> { |
86 |
| - match self.write_tls(cx) { |
87 |
| - Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending, |
88 |
| - result => Poll::Ready(result), |
89 |
| - } |
90 |
| - } |
91 |
| - |
92 |
| - fn complete_inner_io( |
93 |
| - &mut self, |
94 |
| - cx: &mut Context, |
95 |
| - focus: Focus, |
96 |
| - ) -> Poll<io::Result<(usize, usize)>> { |
97 |
| - let mut wrlen = 0; |
98 |
| - let mut rdlen = 0; |
99 |
| - |
100 |
| - loop { |
101 |
| - let mut write_would_block = false; |
102 |
| - let mut read_would_block = false; |
103 |
| - |
104 |
| - while self.session.wants_write() { |
105 |
| - match self.complete_write_io(cx) { |
106 |
| - Poll::Ready(Ok(n)) => wrlen += n, |
107 |
| - Poll::Pending => { |
108 |
| - write_would_block = true; |
109 |
| - break; |
110 |
| - } |
111 |
| - Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), |
112 |
| - } |
113 |
| - } |
114 |
| - |
115 |
| - if !self.eof && self.session.wants_read() { |
116 |
| - match self.complete_read_io(cx) { |
117 |
| - Poll::Ready(Ok(0)) => self.eof = true, |
118 |
| - Poll::Ready(Ok(n)) => rdlen += n, |
119 |
| - Poll::Pending => read_would_block = true, |
120 |
| - Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), |
121 |
| - } |
122 |
| - } |
123 |
| - |
124 |
| - let would_block = match focus { |
125 |
| - Focus::Empty => write_would_block || read_would_block, |
126 |
| - Focus::Readable => read_would_block, |
127 |
| - Focus::Writable => write_would_block, |
128 |
| - }; |
129 |
| - |
130 |
| - match (self.eof, self.session.is_handshaking(), would_block) { |
131 |
| - (true, true, _) => { |
132 |
| - let err = io::Error::new(io::ErrorKind::UnexpectedEof, "tls handshake eof"); |
133 |
| - return Poll::Ready(Err(err)); |
134 |
| - } |
135 |
| - (_, false, true) => { |
136 |
| - let would_block = match focus { |
137 |
| - Focus::Empty => rdlen == 0 && wrlen == 0, |
138 |
| - Focus::Readable => rdlen == 0, |
139 |
| - Focus::Writable => wrlen == 0, |
140 |
| - }; |
141 |
| - |
142 |
| - return if would_block { |
143 |
| - Poll::Pending |
144 |
| - } else { |
145 |
| - Poll::Ready(Ok((rdlen, wrlen))) |
146 |
| - }; |
147 |
| - } |
148 |
| - (_, false, _) => return Poll::Ready(Ok((rdlen, wrlen))), |
149 |
| - (_, true, true) => return Poll::Pending, |
150 |
| - (..) => (), |
151 |
| - } |
152 |
| - } |
153 |
| - } |
154 |
| -} |
155 |
| - |
156 |
| -impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> WriteTls<IO, S> for Stream<'a, IO, S> { |
157 |
| - fn write_tls(&mut self, cx: &mut Context) -> io::Result<usize> { |
158 |
| - // TODO writev |
159 |
| - |
160 |
| - struct Writer<'a, 'b, T> { |
161 |
| - io: &'a mut T, |
162 |
| - cx: &'a mut Context<'b>, |
163 |
| - } |
164 |
| - |
165 |
| - impl<'a, 'b, T: AsyncWrite + Unpin> Write for Writer<'a, 'b, T> { |
166 |
| - fn write(&mut self, buf: &[u8]) -> io::Result<usize> { |
167 |
| - match Pin::new(&mut self.io).poll_write(self.cx, buf) { |
168 |
| - Poll::Ready(result) => result, |
169 |
| - Poll::Pending => Err(io::ErrorKind::WouldBlock.into()), |
170 |
| - } |
171 |
| - } |
172 |
| - |
173 |
| - fn flush(&mut self) -> io::Result<()> { |
174 |
| - match Pin::new(&mut self.io).poll_flush(self.cx) { |
175 |
| - Poll::Ready(result) => result, |
176 |
| - Poll::Pending => Err(io::ErrorKind::WouldBlock.into()), |
177 |
| - } |
178 |
| - } |
179 |
| - } |
180 |
| - |
181 |
| - let mut writer = Writer { io: self.io, cx }; |
182 |
| - self.session.write_tls(&mut writer) |
183 |
| - } |
184 |
| -} |
185 |
| - |
186 |
| -impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncRead for Stream<'a, IO, S> { |
187 |
| - fn poll_read( |
188 |
| - self: Pin<&mut Self>, |
189 |
| - cx: &mut Context, |
190 |
| - buf: &mut [u8], |
191 |
| - ) -> Poll<io::Result<usize>> { |
192 |
| - let this = self.get_mut(); |
193 |
| - |
194 |
| - while this.session.wants_read() { |
195 |
| - match this.complete_inner_io(cx, Focus::Readable) { |
196 |
| - Poll::Ready(Ok((0, _))) => break, |
197 |
| - Poll::Ready(Ok(_)) => (), |
198 |
| - Poll::Pending => return Poll::Pending, |
199 |
| - Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), |
200 |
| - } |
201 |
| - } |
202 |
| - |
203 |
| - match this.session.read(buf) { |
204 |
| - Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending, |
205 |
| - result => Poll::Ready(result), |
206 |
| - } |
207 |
| - } |
208 |
| -} |
209 |
| - |
210 |
| -impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncWrite for Stream<'a, IO, S> { |
211 |
| - fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> { |
212 |
| - let this = self.get_mut(); |
213 |
| - |
214 |
| - let len = match this.session.write(buf) { |
215 |
| - Ok(n) => n, |
216 |
| - Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => return Poll::Pending, |
217 |
| - Err(err) => return Poll::Ready(Err(err)), |
218 |
| - }; |
219 |
| - while this.session.wants_write() { |
220 |
| - match this.complete_inner_io(cx, Focus::Writable) { |
221 |
| - Poll::Ready(Ok(_)) => (), |
222 |
| - Poll::Pending if len != 0 => break, |
223 |
| - Poll::Pending => return Poll::Pending, |
224 |
| - Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), |
225 |
| - } |
226 |
| - } |
227 |
| - |
228 |
| - if len != 0 || buf.is_empty() { |
229 |
| - Poll::Ready(Ok(len)) |
230 |
| - } else { |
231 |
| - // not write zero |
232 |
| - match this.session.write(buf) { |
233 |
| - Ok(0) => Poll::Pending, |
234 |
| - Ok(n) => Poll::Ready(Ok(n)), |
235 |
| - Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending, |
236 |
| - Err(err) => Poll::Ready(Err(err)), |
237 |
| - } |
238 |
| - } |
239 |
| - } |
240 |
| - |
241 |
| - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> { |
242 |
| - let this = self.get_mut(); |
243 |
| - |
244 |
| - this.session.flush()?; |
245 |
| - while this.session.wants_write() { |
246 |
| - futures::ready!(this.complete_inner_io(cx, Focus::Writable))?; |
247 |
| - } |
248 |
| - Pin::new(&mut this.io).poll_flush(cx) |
249 |
| - } |
250 |
| - |
251 |
| - fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { |
252 |
| - let this = self.get_mut(); |
253 |
| - |
254 |
| - while this.session.wants_write() { |
255 |
| - futures::ready!(this.complete_inner_io(cx, Focus::Writable))?; |
256 |
| - } |
257 |
| - Pin::new(&mut this.io).poll_close(cx) |
258 |
| - } |
259 |
| -} |
260 |
| - |
261 |
| -#[cfg(test)] |
262 |
| -mod test_stream; |
0 commit comments