@@ -30,8 +30,7 @@ use webpki::DNSNameRef;
30
30
/// async_std::task::block_on(async {
31
31
/// let connector = TlsConnector::default();
32
32
/// let tcp_stream = async_std::net::TcpStream::connect("example.com").await?;
33
- /// let handshake = connector.connect("example.com", tcp_stream)?;
34
- /// let encrypted_stream = handshake.await?;
33
+ /// let encrypted_stream = connector.connect("example.com", tcp_stream).await?;
35
34
///
36
35
/// Ok(()) as async_std::io::Result<()>
37
36
/// });
@@ -83,11 +82,10 @@ impl TlsConnector {
83
82
/// Connect to a server. `stream` can be any type implementing `AsyncRead` and `AsyncWrite`,
84
83
/// such as TcpStreams or Unix domain sockets.
85
84
///
86
- /// The function will return an error if the domain is not of valid format.
87
- /// Otherwise, it will return a `Connect` Future, representing the connecting part of a
88
- /// Tls handshake. It will resolve when the handshake is over.
85
+ /// The function will return a `Connect` Future, representing the connecting part of a Tls
86
+ /// handshake. It will resolve when the handshake is over.
89
87
#[ inline]
90
- pub fn connect < ' a , IO > ( & self , domain : impl AsRef < str > , stream : IO ) -> io :: Result < Connect < IO > >
88
+ pub fn connect < ' a , IO > ( & self , domain : impl AsRef < str > , stream : IO ) -> Connect < IO >
91
89
where
92
90
IO : AsyncRead + AsyncWrite + Unpin ,
93
91
{
@@ -96,24 +94,27 @@ impl TlsConnector {
96
94
97
95
// NOTE: Currently private, exposing ClientSession exposes rusttls
98
96
// Early data should be exposed differently
99
- fn connect_with < ' a , IO , F > (
100
- & self ,
101
- domain : impl AsRef < str > ,
102
- stream : IO ,
103
- f : F ,
104
- ) -> io:: Result < Connect < IO > >
97
+ fn connect_with < ' a , IO , F > ( & self , domain : impl AsRef < str > , stream : IO , f : F ) -> Connect < IO >
105
98
where
106
99
IO : AsyncRead + AsyncWrite + Unpin ,
107
100
F : FnOnce ( & mut ClientSession ) ,
108
101
{
109
- let domain = DNSNameRef :: try_from_ascii_str ( domain. as_ref ( ) )
110
- . map_err ( |_| io:: Error :: new ( io:: ErrorKind :: InvalidInput , "invalid domain" ) ) ?;
102
+ let domain = match DNSNameRef :: try_from_ascii_str ( domain. as_ref ( ) ) {
103
+ Ok ( domain) => domain,
104
+ Err ( _) => {
105
+ return Connect ( ConnectInner :: Error ( Some ( io:: Error :: new (
106
+ io:: ErrorKind :: InvalidInput ,
107
+ "invalid domain" ,
108
+ ) ) ) )
109
+ }
110
+ } ;
111
+
111
112
let mut session = ClientSession :: new ( & self . inner , domain) ;
112
113
f ( & mut session) ;
113
114
114
115
#[ cfg( not( feature = "early-data" ) ) ]
115
116
{
116
- Ok ( Connect ( client:: MidHandshake :: Handshaking (
117
+ Connect ( ConnectInner :: Handshake ( client:: MidHandshake :: Handshaking (
117
118
client:: TlsStream {
118
119
session,
119
120
io : stream,
@@ -124,7 +125,7 @@ impl TlsConnector {
124
125
125
126
#[ cfg( feature = "early-data" ) ]
126
127
{
127
- Ok ( Connect ( if self . early_data {
128
+ Connect ( ConnectInner :: Handshake ( if self . early_data {
128
129
client:: MidHandshake :: EarlyData ( client:: TlsStream {
129
130
session,
130
131
io : stream,
@@ -145,13 +146,23 @@ impl TlsConnector {
145
146
146
147
/// Future returned from `TlsConnector::connect` which will resolve
147
148
/// once the connection handshake has finished.
148
- pub struct Connect < IO > ( client:: MidHandshake < IO > ) ;
149
+ pub struct Connect < IO > ( ConnectInner < IO > ) ;
150
+
151
+ enum ConnectInner < IO > {
152
+ Error ( Option < io:: Error > ) ,
153
+ Handshake ( client:: MidHandshake < IO > ) ,
154
+ }
149
155
150
156
impl < IO : AsyncRead + AsyncWrite + Unpin > Future for Connect < IO > {
151
157
type Output = io:: Result < client:: TlsStream < IO > > ;
152
158
153
159
#[ inline]
154
160
fn poll ( mut self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Self :: Output > {
155
- Pin :: new ( & mut self . 0 ) . poll ( cx)
161
+ match self . 0 {
162
+ ConnectInner :: Error ( ref mut err) => {
163
+ Poll :: Ready ( Err ( err. take ( ) . expect ( "Polled twice after being Ready" ) ) )
164
+ }
165
+ ConnectInner :: Handshake ( ref mut handshake) => Pin :: new ( handshake) . poll ( cx) ,
166
+ }
156
167
}
157
168
}
0 commit comments