|
3 | 3 | use std::time::Duration;
|
4 | 4 |
|
5 | 5 | use async_std::future::{timeout, Future, TimeoutError};
|
6 |
| -use async_std::io::{self}; |
7 |
| -use async_std::io::{Read, Write}; |
8 |
| -use http_types::{Request, Response}; |
| 6 | +use async_std::io::{self, Read, Write}; |
| 7 | +use http_types::headers::{CONNECTION, UPGRADE}; |
| 8 | +use http_types::upgrade::Connection; |
| 9 | +use http_types::{Request, Response, StatusCode}; |
9 | 10 |
|
10 | 11 | mod decode;
|
11 | 12 | mod encode;
|
@@ -70,14 +71,32 @@ where
|
70 | 71 | }
|
71 | 72 | };
|
72 | 73 |
|
| 74 | + let upgrade_requested = match (req.header(UPGRADE), req.header(CONNECTION)) { |
| 75 | + (Some(_), Some(upgrade)) if upgrade.as_str().eq_ignore_ascii_case("upgrade") => true, |
| 76 | + _ => false, |
| 77 | + }; |
| 78 | + |
73 | 79 | let method = req.method();
|
| 80 | + |
74 | 81 | // Pass the request to the endpoint and encode the response.
|
75 |
| - let res = endpoint(req).await?; |
| 82 | + let mut res = endpoint(req).await?; |
| 83 | + |
| 84 | + let upgrade_provided = res.status() == StatusCode::SwitchingProtocols && res.has_upgrade(); |
| 85 | + |
| 86 | + let upgrade_sender = if upgrade_requested && upgrade_provided { |
| 87 | + Some(res.send_upgrade()) |
| 88 | + } else { |
| 89 | + None |
| 90 | + }; |
76 | 91 |
|
77 | 92 | let mut encoder = Encoder::new(res, method);
|
78 | 93 |
|
79 | 94 | // Stream the response to the writer.
|
80 | 95 | io::copy(&mut encoder, &mut io).await?;
|
| 96 | + |
| 97 | + if let Some(upgrade_sender) = upgrade_sender { |
| 98 | + upgrade_sender.send(Connection::new(io.clone())).await; |
| 99 | + } |
81 | 100 | }
|
82 | 101 |
|
83 | 102 | Ok(())
|
|
0 commit comments