@@ -19,7 +19,7 @@ use rustls::{
19
19
use crate :: error:: Error ;
20
20
use crate :: io:: ReadBuf ;
21
21
use crate :: net:: tls:: util:: StdSocket ;
22
- use crate :: net:: tls:: TlsConfig ;
22
+ use crate :: net:: tls:: { RawTlsConfig , TlsConfig } ;
23
23
use crate :: net:: Socket ;
24
24
25
25
pub struct RustlsSocket < S : Socket > {
@@ -87,100 +87,134 @@ impl<S: Socket> Socket for RustlsSocket<S> {
87
87
}
88
88
}
89
89
90
- pub async fn handshake < S > ( socket : S , tls_config : TlsConfig < ' _ > ) -> Result < RustlsSocket < S > , Error >
91
- where
92
- S : Socket ,
93
- {
94
- #[ cfg( all(
95
- feature = "_tls-rustls-aws-lc-rs" ,
96
- not( feature = "_tls-rustls-ring-webpki" ) ,
97
- not( feature = "_tls-rustls-ring-native-roots" )
98
- ) ) ]
99
- let provider = Arc :: new ( rustls:: crypto:: aws_lc_rs:: default_provider ( ) ) ;
100
- #[ cfg( any(
101
- feature = "_tls-rustls-ring-webpki" ,
102
- feature = "_tls-rustls-ring-native-roots"
103
- ) ) ]
104
- let provider = Arc :: new ( rustls:: crypto:: ring:: default_provider ( ) ) ;
105
-
106
- // Unwrapping is safe here because we use a default provider.
107
- let config = ClientConfig :: builder_with_provider ( provider. clone ( ) )
90
+ impl TlsConfig < ' _ > {
91
+ async fn rustls_config ( & self ) -> crate :: Result < ( rustls:: ClientConfig , & str ) , Error > {
92
+ let RawTlsConfig {
93
+ accept_invalid_certs,
94
+ accept_invalid_hostnames,
95
+ hostname,
96
+ root_cert,
97
+ client_cert,
98
+ client_key,
99
+ } = match self {
100
+ TlsConfig :: RawTlsConfig ( raw) => raw,
101
+ TlsConfig :: PrebuiltRustls { config, hostname } => {
102
+ return Ok ( ( ( * config) . to_owned ( ) , hostname) ) ;
103
+ }
104
+ } ;
105
+
106
+ #[ cfg( all(
107
+ feature = "_tls-rustls-aws-lc-rs" ,
108
+ not( feature = "_tls-rustls-ring-webpki" ) ,
109
+ not( feature = "_tls-rustls-ring-native-roots" )
110
+ ) ) ]
111
+ let config = ClientConfig :: builder_with_provider ( Arc :: new (
112
+ rustls:: crypto:: aws_lc_rs:: default_provider ( ) ,
113
+ ) )
108
114
. with_safe_default_protocol_versions ( )
109
115
. unwrap ( ) ;
116
+ #[ cfg( any(
117
+ feature = "_tls-rustls-ring-webpki" ,
118
+ feature = "_tls-rustls-ring-native-roots"
119
+ ) ) ]
120
+ let config =
121
+ ClientConfig :: builder_with_provider ( Arc :: new ( rustls:: crypto:: ring:: default_provider ( ) ) )
122
+ . with_safe_default_protocol_versions ( )
123
+ . unwrap ( ) ;
124
+ #[ cfg( all(
125
+ not( feature = "_tls-rustls-ring-webpki" ) ,
126
+ not( feature = "_tls-rustls-ring-native-roots" )
127
+ ) ) ]
128
+ let config = ClientConfig :: builder ( ) ;
129
+
130
+ // authentication using user's key and its associated certificate
131
+ let user_auth = match ( client_cert, client_key) {
132
+ ( Some ( cert) , Some ( key) ) => {
133
+ let cert_chain = certs_from_pem ( cert. data ( ) . await ?) ?;
134
+ let key_der = private_key_from_pem ( key. data ( ) . await ?) ?;
135
+ Some ( ( cert_chain, key_der) )
136
+ }
137
+ ( None , None ) => None ,
138
+ ( _, _) => {
139
+ return Err ( Error :: Configuration (
140
+ "user auth key and certs must be given together" . into ( ) ,
141
+ ) )
142
+ }
143
+ } ;
110
144
111
- // authentication using user's key and its associated certificate
112
- let user_auth = match ( tls_config. client_cert_path , tls_config. client_key_path ) {
113
- ( Some ( cert_path) , Some ( key_path) ) => {
114
- let cert_chain = certs_from_pem ( cert_path. data ( ) . await ?) ?;
115
- let key_der = private_key_from_pem ( key_path. data ( ) . await ?) ?;
116
- Some ( ( cert_chain, key_der) )
117
- }
118
- ( None , None ) => None ,
119
- ( _, _) => {
120
- return Err ( Error :: Configuration (
121
- "user auth key and certs must be given together" . into ( ) ,
122
- ) )
123
- }
124
- } ;
145
+ let provider = config. crypto_provider ( ) . clone ( ) ;
125
146
126
- let config = if tls_config. accept_invalid_certs {
127
- if let Some ( user_auth) = user_auth {
128
- config
129
- . dangerous ( )
130
- . with_custom_certificate_verifier ( Arc :: new ( DummyTlsVerifier { provider } ) )
131
- . with_client_auth_cert ( user_auth. 0 , user_auth. 1 )
132
- . map_err ( Error :: tls) ?
147
+ let config = if * accept_invalid_certs {
148
+ if let Some ( user_auth) = user_auth {
149
+ config
150
+ . dangerous ( )
151
+ . with_custom_certificate_verifier ( Arc :: new ( DummyTlsVerifier { provider } ) )
152
+ . with_client_auth_cert ( user_auth. 0 , user_auth. 1 )
153
+ . map_err ( Error :: tls) ?
154
+ } else {
155
+ config
156
+ . dangerous ( )
157
+ . with_custom_certificate_verifier ( Arc :: new ( DummyTlsVerifier { provider } ) )
158
+ . with_no_client_auth ( )
159
+ }
133
160
} else {
134
- config
135
- . dangerous ( )
136
- . with_custom_certificate_verifier ( Arc :: new ( DummyTlsVerifier { provider } ) )
137
- . with_no_client_auth ( )
138
- }
139
- } else {
140
- let mut cert_store = import_root_certs ( ) ;
161
+ let mut cert_store = import_root_certs ( ) ;
141
162
142
- if let Some ( ca) = tls_config . root_cert_path {
143
- let data = ca. data ( ) . await ?;
163
+ if let Some ( ca) = root_cert {
164
+ let data = ca. data ( ) . await ?;
144
165
145
- for result in CertificateDer :: pem_slice_iter ( & data) {
146
- let Ok ( cert) = result else {
147
- return Err ( Error :: Tls ( format ! ( "Invalid certificate {ca}" ) . into ( ) ) ) ;
148
- } ;
166
+ for result in CertificateDer :: pem_slice_iter ( & data) {
167
+ let Ok ( cert) = result else {
168
+ return Err ( Error :: Tls ( format ! ( "Invalid certificate {ca}" ) . into ( ) ) ) ;
169
+ } ;
149
170
150
- cert_store. add ( cert) . map_err ( |err| Error :: Tls ( err. into ( ) ) ) ?;
171
+ cert_store. add ( cert) . map_err ( |err| Error :: Tls ( err. into ( ) ) ) ?;
172
+ }
151
173
}
152
- }
153
-
154
- if tls_config. accept_invalid_hostnames {
155
- let verifier = WebPkiServerVerifier :: builder ( Arc :: new ( cert_store) )
156
- . build ( )
157
- . map_err ( |err| Error :: Tls ( err. into ( ) ) ) ?;
158
174
159
- if let Some ( user_auth) = user_auth {
175
+ if * accept_invalid_hostnames {
176
+ let verifier = WebPkiServerVerifier :: builder ( Arc :: new ( cert_store) )
177
+ . build ( )
178
+ . map_err ( |err| Error :: Tls ( err. into ( ) ) ) ?;
179
+
180
+ if let Some ( user_auth) = user_auth {
181
+ config
182
+ . dangerous ( )
183
+ . with_custom_certificate_verifier ( Arc :: new ( NoHostnameTlsVerifier {
184
+ verifier,
185
+ } ) )
186
+ . with_client_auth_cert ( user_auth. 0 , user_auth. 1 )
187
+ . map_err ( Error :: tls) ?
188
+ } else {
189
+ config
190
+ . dangerous ( )
191
+ . with_custom_certificate_verifier ( Arc :: new ( NoHostnameTlsVerifier {
192
+ verifier,
193
+ } ) )
194
+ . with_no_client_auth ( )
195
+ }
196
+ } else if let Some ( user_auth) = user_auth {
160
197
config
161
- . dangerous ( )
162
- . with_custom_certificate_verifier ( Arc :: new ( NoHostnameTlsVerifier { verifier } ) )
198
+ . with_root_certificates ( cert_store)
163
199
. with_client_auth_cert ( user_auth. 0 , user_auth. 1 )
164
200
. map_err ( Error :: tls) ?
165
201
} else {
166
202
config
167
- . dangerous ( )
168
- . with_custom_certificate_verifier ( Arc :: new ( NoHostnameTlsVerifier { verifier } ) )
203
+ . with_root_certificates ( cert_store)
169
204
. with_no_client_auth ( )
170
205
}
171
- } else if let Some ( user_auth) = user_auth {
172
- config
173
- . with_root_certificates ( cert_store)
174
- . with_client_auth_cert ( user_auth. 0 , user_auth. 1 )
175
- . map_err ( Error :: tls) ?
176
- } else {
177
- config
178
- . with_root_certificates ( cert_store)
179
- . with_no_client_auth ( )
180
- }
181
- } ;
206
+ } ;
207
+
208
+ Ok ( ( config, hostname) )
209
+ }
210
+ }
182
211
183
- let host = ServerName :: try_from ( tls_config. hostname . to_owned ( ) ) . map_err ( Error :: tls) ?;
212
+ pub async fn handshake < S > ( socket : S , tls_config : TlsConfig < ' _ > ) -> Result < RustlsSocket < S > , Error >
213
+ where
214
+ S : Socket ,
215
+ {
216
+ let ( config, hostname) = tls_config. rustls_config ( ) . await ?;
217
+ let host = ServerName :: try_from ( hostname. to_owned ( ) ) . map_err ( Error :: tls) ?;
184
218
185
219
let mut socket = RustlsSocket {
186
220
inner : StdSocket :: new ( socket) ,
0 commit comments