Skip to content

Commit 9b6b1b9

Browse files
sdroegedignifiedquire
authored andcommitted
feat: make connect return a Future directly
Don't return a io::Result<impl Future<_>> from Connector::connect() but only a plain future. `let conn = connector.connect(...)?.await?` is a bit awkward to use because of the double ?-operator. Instead always return a future now that in case of domain format errors immediately resolves to the error. Closes #10
1 parent 120e379 commit 9b6b1b9

File tree

6 files changed

+34
-27
lines changed

6 files changed

+34
-27
lines changed

README.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,7 @@ use async_std::net::TcpStream;
5858

5959
let tcp_stream = TcpStream::connect("rust-lang.org:443").await?;
6060
let connector = TlsConnector::default();
61-
let handshake = connector.connect("www.rust-lang.org", tcp_stream)?;
62-
let mut tls_stream = handshake.await?;
61+
let mut tls_stream = connector.connect("www.rust-lang.org", tcp_stream).await?;
6362

6463
// ...
6564
```

examples/client/src/main.rs

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -63,13 +63,10 @@ fn main() -> io::Result<()> {
6363
let tcp_stream = TcpStream::connect(&addr).await?;
6464

6565
// Use the connector to start the handshake process.
66-
// This might fail early if you pass an invalid domain,
67-
// which is why we use `?`.
6866
// This consumes the TCP stream to ensure you are not reusing it.
69-
let handshake = connector.connect(&domain, tcp_stream)?;
7067
// Awaiting the handshake gives you an encrypted
7168
// stream back which you can use like any other.
72-
let mut tls_stream = handshake.await?;
69+
let mut tls_stream = connector.connect(&domain, tcp_stream).await?;
7370

7471
// We write our crafted HTTP request to it
7572
tls_stream.write_all(http_request.as_bytes()).await?;

src/connector.rs

Lines changed: 29 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,7 @@ use webpki::DNSNameRef;
3030
/// async_std::task::block_on(async {
3131
/// let connector = TlsConnector::default();
3232
/// let tcp_stream = async_std::net::TcpStream::connect("example.com").await?;
33-
/// let handshake = connector.connect("example.com", tcp_stream)?;
34-
/// let encrypted_stream = handshake.await?;
33+
/// let encrypted_stream = connector.connect("example.com", tcp_stream).await?;
3534
///
3635
/// Ok(()) as async_std::io::Result<()>
3736
/// });
@@ -83,11 +82,10 @@ impl TlsConnector {
8382
/// Connect to a server. `stream` can be any type implementing `AsyncRead` and `AsyncWrite`,
8483
/// such as TcpStreams or Unix domain sockets.
8584
///
86-
/// The function will return an error if the domain is not of valid format.
87-
/// Otherwise, it will return a `Connect` Future, representing the connecting part of a
88-
/// Tls handshake. It will resolve when the handshake is over.
85+
/// The function will return a `Connect` Future, representing the connecting part of a Tls
86+
/// handshake. It will resolve when the handshake is over.
8987
#[inline]
90-
pub fn connect<'a, IO>(&self, domain: impl AsRef<str>, stream: IO) -> io::Result<Connect<IO>>
88+
pub fn connect<'a, IO>(&self, domain: impl AsRef<str>, stream: IO) -> Connect<IO>
9189
where
9290
IO: AsyncRead + AsyncWrite + Unpin,
9391
{
@@ -96,24 +94,27 @@ impl TlsConnector {
9694

9795
// NOTE: Currently private, exposing ClientSession exposes rusttls
9896
// Early data should be exposed differently
99-
fn connect_with<'a, IO, F>(
100-
&self,
101-
domain: impl AsRef<str>,
102-
stream: IO,
103-
f: F,
104-
) -> io::Result<Connect<IO>>
97+
fn connect_with<'a, IO, F>(&self, domain: impl AsRef<str>, stream: IO, f: F) -> Connect<IO>
10598
where
10699
IO: AsyncRead + AsyncWrite + Unpin,
107100
F: FnOnce(&mut ClientSession),
108101
{
109-
let domain = DNSNameRef::try_from_ascii_str(domain.as_ref())
110-
.map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid domain"))?;
102+
let domain = match DNSNameRef::try_from_ascii_str(domain.as_ref()) {
103+
Ok(domain) => domain,
104+
Err(_) => {
105+
return Connect(ConnectInner::Error(Some(io::Error::new(
106+
io::ErrorKind::InvalidInput,
107+
"invalid domain",
108+
))))
109+
}
110+
};
111+
111112
let mut session = ClientSession::new(&self.inner, domain);
112113
f(&mut session);
113114

114115
#[cfg(not(feature = "early-data"))]
115116
{
116-
Ok(Connect(client::MidHandshake::Handshaking(
117+
Connect(ConnectInner::Handshake(client::MidHandshake::Handshaking(
117118
client::TlsStream {
118119
session,
119120
io: stream,
@@ -124,7 +125,7 @@ impl TlsConnector {
124125

125126
#[cfg(feature = "early-data")]
126127
{
127-
Ok(Connect(if self.early_data {
128+
Connect(ConnectInner::Handshake(if self.early_data {
128129
client::MidHandshake::EarlyData(client::TlsStream {
129130
session,
130131
io: stream,
@@ -145,13 +146,23 @@ impl TlsConnector {
145146

146147
/// Future returned from `TlsConnector::connect` which will resolve
147148
/// once the connection handshake has finished.
148-
pub struct Connect<IO>(client::MidHandshake<IO>);
149+
pub struct Connect<IO>(ConnectInner<IO>);
150+
151+
enum ConnectInner<IO> {
152+
Error(Option<io::Error>),
153+
Handshake(client::MidHandshake<IO>),
154+
}
149155

150156
impl<IO: AsyncRead + AsyncWrite + Unpin> Future for Connect<IO> {
151157
type Output = io::Result<client::TlsStream<IO>>;
152158

153159
#[inline]
154160
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
155-
Pin::new(&mut self.0).poll(cx)
161+
match self.0 {
162+
ConnectInner::Error(ref mut err) => {
163+
Poll::Ready(Err(err.take().expect("Polled twice after being Ready")))
164+
}
165+
ConnectInner::Handshake(ref mut handshake) => Pin::new(handshake).poll(cx),
166+
}
156167
}
157168
}

src/test_0rtt.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ async fn get(
1919
let mut buf = Vec::new();
2020

2121
let stream = TcpStream::connect(&addr).await?;
22-
let mut stream = connector.connect(domain, stream)?.await?;
22+
let mut stream = connector.connect(domain, stream).await?;
2323
stream.write_all(input.as_bytes()).await?;
2424
stream.read_to_end(&mut buf).await?;
2525

tests/google.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ fn fetch_google() -> std::io::Result<()> {
99
let connector = TlsConnector::default();
1010

1111
let stream = TcpStream::connect("google.com:443").await?;
12-
let mut stream = connector.connect("google.com", stream)?.await?;
12+
let mut stream = connector.connect("google.com", stream).await?;
1313

1414
stream.write_all(b"GET / HTTP/1.0\r\n\r\n").await?;
1515
let mut res = vec![];

tests/test.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ async fn start_client(addr: SocketAddr, domain: &str, config: Arc<ClientConfig>)
6767
let mut buf = vec![0; FILE.len()];
6868

6969
let stream = TcpStream::connect(&addr).await?;
70-
let mut stream = config.connect(domain, stream)?.await?;
70+
let mut stream = config.connect(domain, stream).await?;
7171
stream.write_all(FILE).await?;
7272
stream.read_exact(&mut buf).await?;
7373

0 commit comments

Comments
 (0)