@@ -19,7 +19,7 @@ use rustls::{
19
19
use crate :: error:: Error ;
20
20
use crate :: io:: ReadBuf ;
21
21
use crate :: net:: tls:: util:: StdSocket ;
22
- use crate :: net:: tls:: TlsConfig ;
22
+ use crate :: net:: tls:: { RawTlsConfig , TlsConfig } ;
23
23
use crate :: net:: Socket ;
24
24
25
25
pub struct RustlsSocket < S : Socket > {
@@ -87,100 +87,125 @@ impl<S: Socket> Socket for RustlsSocket<S> {
87
87
}
88
88
}
89
89
90
- pub async fn handshake < S > ( socket : S , tls_config : TlsConfig < ' _ > ) -> Result < RustlsSocket < S > , Error >
91
- where
92
- S : Socket ,
93
- {
94
- #[ cfg( all(
95
- feature = "_tls-rustls-aws-lc-rs" ,
96
- not( feature = "_tls-rustls-ring-webpki" ) ,
97
- not( feature = "_tls-rustls-ring-native-roots" )
98
- ) ) ]
99
- let provider = Arc :: new ( rustls:: crypto:: aws_lc_rs:: default_provider ( ) ) ;
100
- #[ cfg( any(
101
- feature = "_tls-rustls-ring-webpki" ,
102
- feature = "_tls-rustls-ring-native-roots"
103
- ) ) ]
104
- let provider = Arc :: new ( rustls:: crypto:: ring:: default_provider ( ) ) ;
105
-
106
- // Unwrapping is safe here because we use a default provider.
107
- let config = ClientConfig :: builder_with_provider ( provider. clone ( ) )
108
- . with_safe_default_protocol_versions ( )
109
- . unwrap ( ) ;
110
-
111
- // authentication using user's key and its associated certificate
112
- let user_auth = match ( tls_config. client_cert_path , tls_config. client_key_path ) {
113
- ( Some ( cert_path) , Some ( key_path) ) => {
114
- let cert_chain = certs_from_pem ( cert_path. data ( ) . await ?) ?;
115
- let key_der = private_key_from_pem ( key_path. data ( ) . await ?) ?;
116
- Some ( ( cert_chain, key_der) )
117
- }
118
- ( None , None ) => None ,
119
- ( _, _) => {
120
- return Err ( Error :: Configuration (
121
- "user auth key and certs must be given together" . into ( ) ,
122
- ) )
123
- }
124
- } ;
90
+ impl TlsConfig < ' _ > {
91
+ async fn rustls_config ( & self ) -> crate :: Result < ( rustls:: ClientConfig , & str ) , Error > {
92
+ let RawTlsConfig {
93
+ accept_invalid_certs,
94
+ accept_invalid_hostnames,
95
+ hostname,
96
+ root_cert,
97
+ client_cert,
98
+ client_key,
99
+ } = match self {
100
+ TlsConfig :: RawTlsConfig ( raw) => raw,
101
+ TlsConfig :: PrebuiltRustls { config, hostname } => {
102
+ return Ok ( ( ( * config) . to_owned ( ) , hostname) ) ;
103
+ }
104
+ } ;
105
+
106
+ #[ cfg( all(
107
+ feature = "_tls-rustls-aws-lc-rs" ,
108
+ not( feature = "_tls-rustls-ring-webpki" ) ,
109
+ not( feature = "_tls-rustls-ring-native-roots" )
110
+ ) ) ]
111
+ let provider = Arc :: new ( rustls:: crypto:: aws_lc_rs:: default_provider ( ) ) ;
112
+ #[ cfg( any(
113
+ feature = "_tls-rustls-ring-webpki" ,
114
+ feature = "_tls-rustls-ring-native-roots"
115
+ ) ) ]
116
+ let provider = Arc :: new ( rustls:: crypto:: ring:: default_provider ( ) ) ;
117
+
118
+ // Unwrapping is safe here because we use a default provider.
119
+ let config = ClientConfig :: builder_with_provider ( provider. clone ( ) )
120
+ . with_safe_default_protocol_versions ( )
121
+ . unwrap ( ) ;
122
+
123
+ // authentication using user's key and its associated certificate
124
+ let user_auth = match ( client_cert, client_key) {
125
+ ( Some ( cert) , Some ( key) ) => {
126
+ let cert_chain = certs_from_pem ( cert. data ( ) . await ?) ?;
127
+ let key_der = private_key_from_pem ( key. data ( ) . await ?) ?;
128
+ Some ( ( cert_chain, key_der) )
129
+ }
130
+ ( None , None ) => None ,
131
+ ( _, _) => {
132
+ return Err ( Error :: Configuration (
133
+ "user auth key and certs must be given together" . into ( ) ,
134
+ ) )
135
+ }
136
+ } ;
125
137
126
- let config = if tls_config. accept_invalid_certs {
127
- if let Some ( user_auth) = user_auth {
128
- config
129
- . dangerous ( )
130
- . with_custom_certificate_verifier ( Arc :: new ( DummyTlsVerifier { provider } ) )
131
- . with_client_auth_cert ( user_auth. 0 , user_auth. 1 )
132
- . map_err ( Error :: tls) ?
138
+ let config = if * accept_invalid_certs {
139
+ if let Some ( user_auth) = user_auth {
140
+ config
141
+ . dangerous ( )
142
+ . with_custom_certificate_verifier ( Arc :: new ( DummyTlsVerifier { provider } ) )
143
+ . with_client_auth_cert ( user_auth. 0 , user_auth. 1 )
144
+ . map_err ( Error :: tls) ?
145
+ } else {
146
+ config
147
+ . dangerous ( )
148
+ . with_custom_certificate_verifier ( Arc :: new ( DummyTlsVerifier { provider } ) )
149
+ . with_no_client_auth ( )
150
+ }
133
151
} else {
134
- config
135
- . dangerous ( )
136
- . with_custom_certificate_verifier ( Arc :: new ( DummyTlsVerifier { provider } ) )
137
- . with_no_client_auth ( )
138
- }
139
- } else {
140
- let mut cert_store = import_root_certs ( ) ;
152
+ let mut cert_store = import_root_certs ( ) ;
141
153
142
- if let Some ( ca) = tls_config . root_cert_path {
143
- let data = ca. data ( ) . await ?;
154
+ if let Some ( ca) = root_cert {
155
+ let data = ca. data ( ) . await ?;
144
156
145
- for result in CertificateDer :: pem_slice_iter ( & data) {
146
- let Ok ( cert) = result else {
147
- return Err ( Error :: Tls ( format ! ( "Invalid certificate {ca}" ) . into ( ) ) ) ;
148
- } ;
157
+ for result in CertificateDer :: pem_slice_iter ( & data) {
158
+ let Ok ( cert) = result else {
159
+ return Err ( Error :: Tls ( format ! ( "Invalid certificate {ca}" ) . into ( ) ) ) ;
160
+ } ;
149
161
150
- cert_store. add ( cert) . map_err ( |err| Error :: Tls ( err. into ( ) ) ) ?;
162
+ cert_store. add ( cert) . map_err ( |err| Error :: Tls ( err. into ( ) ) ) ?;
163
+ }
151
164
}
152
- }
153
-
154
- if tls_config. accept_invalid_hostnames {
155
- let verifier = WebPkiServerVerifier :: builder ( Arc :: new ( cert_store) )
156
- . build ( )
157
- . map_err ( |err| Error :: Tls ( err. into ( ) ) ) ?;
158
165
159
- if let Some ( user_auth) = user_auth {
166
+ if * accept_invalid_hostnames {
167
+ let verifier = WebPkiServerVerifier :: builder ( Arc :: new ( cert_store) )
168
+ . build ( )
169
+ . map_err ( |err| Error :: Tls ( err. into ( ) ) ) ?;
170
+
171
+ if let Some ( user_auth) = user_auth {
172
+ config
173
+ . dangerous ( )
174
+ . with_custom_certificate_verifier ( Arc :: new ( NoHostnameTlsVerifier {
175
+ verifier,
176
+ } ) )
177
+ . with_client_auth_cert ( user_auth. 0 , user_auth. 1 )
178
+ . map_err ( Error :: tls) ?
179
+ } else {
180
+ config
181
+ . dangerous ( )
182
+ . with_custom_certificate_verifier ( Arc :: new ( NoHostnameTlsVerifier {
183
+ verifier,
184
+ } ) )
185
+ . with_no_client_auth ( )
186
+ }
187
+ } else if let Some ( user_auth) = user_auth {
160
188
config
161
- . dangerous ( )
162
- . with_custom_certificate_verifier ( Arc :: new ( NoHostnameTlsVerifier { verifier } ) )
189
+ . with_root_certificates ( cert_store)
163
190
. with_client_auth_cert ( user_auth. 0 , user_auth. 1 )
164
191
. map_err ( Error :: tls) ?
165
192
} else {
166
193
config
167
- . dangerous ( )
168
- . with_custom_certificate_verifier ( Arc :: new ( NoHostnameTlsVerifier { verifier } ) )
194
+ . with_root_certificates ( cert_store)
169
195
. with_no_client_auth ( )
170
196
}
171
- } else if let Some ( user_auth) = user_auth {
172
- config
173
- . with_root_certificates ( cert_store)
174
- . with_client_auth_cert ( user_auth. 0 , user_auth. 1 )
175
- . map_err ( Error :: tls) ?
176
- } else {
177
- config
178
- . with_root_certificates ( cert_store)
179
- . with_no_client_auth ( )
180
- }
181
- } ;
197
+ } ;
198
+
199
+ Ok ( ( config, hostname) )
200
+ }
201
+ }
182
202
183
- let host = ServerName :: try_from ( tls_config. hostname . to_owned ( ) ) . map_err ( Error :: tls) ?;
203
+ pub async fn handshake < S > ( socket : S , tls_config : TlsConfig < ' _ > ) -> Result < RustlsSocket < S > , Error >
204
+ where
205
+ S : Socket ,
206
+ {
207
+ let ( config, hostname) = tls_config. rustls_config ( ) . await ?;
208
+ let host = ServerName :: try_from ( hostname. to_owned ( ) ) . map_err ( Error :: tls) ?;
184
209
185
210
let mut socket = RustlsSocket {
186
211
inner : StdSocket :: new ( socket) ,
0 commit comments