@@ -19,7 +19,7 @@ use rustls::{
19
19
use crate :: error:: Error ;
20
20
use crate :: io:: ReadBuf ;
21
21
use crate :: net:: tls:: util:: StdSocket ;
22
- use crate :: net:: tls:: TlsConfig ;
22
+ use crate :: net:: tls:: { RawTlsConfig , TlsConfig } ;
23
23
use crate :: net:: Socket ;
24
24
25
25
pub struct RustlsSocket < S : Socket > {
@@ -87,100 +87,135 @@ impl<S: Socket> Socket for RustlsSocket<S> {
87
87
}
88
88
}
89
89
90
- pub async fn handshake < S > ( socket : S , tls_config : TlsConfig < ' _ > ) -> Result < RustlsSocket < S > , Error >
91
- where
92
- S : Socket ,
93
- {
94
- #[ cfg( all(
95
- feature = "_tls-rustls-aws-lc-rs" ,
96
- not( feature = "_tls-rustls-ring-webpki" ) ,
97
- not( feature = "_tls-rustls-ring-native-roots" )
98
- ) ) ]
99
- let provider = Arc :: new ( rustls:: crypto:: aws_lc_rs:: default_provider ( ) ) ;
100
- #[ cfg( any(
101
- feature = "_tls-rustls-ring-webpki" ,
102
- feature = "_tls-rustls-ring-native-roots"
103
- ) ) ]
104
- let provider = Arc :: new ( rustls:: crypto:: ring:: default_provider ( ) ) ;
105
-
106
- // Unwrapping is safe here because we use a default provider.
107
- let config = ClientConfig :: builder_with_provider ( provider. clone ( ) )
90
+ impl TlsConfig < ' _ > {
91
+ async fn rustls_config ( & self ) -> crate :: Result < ( rustls:: ClientConfig , & str ) , Error > {
92
+ let RawTlsConfig {
93
+ accept_invalid_certs,
94
+ accept_invalid_hostnames,
95
+ hostname,
96
+ root_cert,
97
+ client_cert,
98
+ client_key,
99
+ } = match self {
100
+ TlsConfig :: RawTlsConfig ( raw) => raw,
101
+ TlsConfig :: PrebuiltRustls { config, hostname } => {
102
+ return Ok ( ( ( * config) . to_owned ( ) , hostname) ) ;
103
+ }
104
+ } ;
105
+
106
+ #[ cfg( all(
107
+ feature = "_tls-rustls-aws-lc-rs" ,
108
+ not( feature = "_tls-rustls-ring-webpki" ) ,
109
+ not( feature = "_tls-rustls-ring-native-roots" )
110
+ ) ) ]
111
+ let config = ClientConfig :: builder_with_provider ( Arc :: new (
112
+ rustls:: crypto:: aws_lc_rs:: default_provider ( ) ,
113
+ ) )
108
114
. with_safe_default_protocol_versions ( )
109
115
. unwrap ( ) ;
116
+ #[ cfg( any(
117
+ feature = "_tls-rustls-ring-webpki" ,
118
+ feature = "_tls-rustls-ring-native-roots"
119
+ ) ) ]
120
+ let config =
121
+ ClientConfig :: builder_with_provider ( Arc :: new ( rustls:: crypto:: ring:: default_provider ( ) ) )
122
+ . with_safe_default_protocol_versions ( )
123
+ . unwrap ( ) ;
124
+ #[ cfg( all(
125
+ not( feature = "_tls-rustls-aws-lc-rs" ) ,
126
+ not( feature = "_tls-rustls-ring-webpki" ) ,
127
+ not( feature = "_tls-rustls-ring-native-roots" )
128
+ ) ) ]
129
+ let config = ClientConfig :: builder ( ) ;
130
+
131
+ // authentication using user's key and its associated certificate
132
+ let user_auth = match ( client_cert, client_key) {
133
+ ( Some ( cert) , Some ( key) ) => {
134
+ let cert_chain = certs_from_pem ( cert. data ( ) . await ?) ?;
135
+ let key_der = private_key_from_pem ( key. data ( ) . await ?) ?;
136
+ Some ( ( cert_chain, key_der) )
137
+ }
138
+ ( None , None ) => None ,
139
+ ( _, _) => {
140
+ return Err ( Error :: Configuration (
141
+ "user auth key and certs must be given together" . into ( ) ,
142
+ ) )
143
+ }
144
+ } ;
110
145
111
- // authentication using user's key and its associated certificate
112
- let user_auth = match ( tls_config. client_cert_path , tls_config. client_key_path ) {
113
- ( Some ( cert_path) , Some ( key_path) ) => {
114
- let cert_chain = certs_from_pem ( cert_path. data ( ) . await ?) ?;
115
- let key_der = private_key_from_pem ( key_path. data ( ) . await ?) ?;
116
- Some ( ( cert_chain, key_der) )
117
- }
118
- ( None , None ) => None ,
119
- ( _, _) => {
120
- return Err ( Error :: Configuration (
121
- "user auth key and certs must be given together" . into ( ) ,
122
- ) )
123
- }
124
- } ;
146
+ let provider = config. crypto_provider ( ) . clone ( ) ;
125
147
126
- let config = if tls_config. accept_invalid_certs {
127
- if let Some ( user_auth) = user_auth {
128
- config
129
- . dangerous ( )
130
- . with_custom_certificate_verifier ( Arc :: new ( DummyTlsVerifier { provider } ) )
131
- . with_client_auth_cert ( user_auth. 0 , user_auth. 1 )
132
- . map_err ( Error :: tls) ?
148
+ let config = if * accept_invalid_certs {
149
+ if let Some ( user_auth) = user_auth {
150
+ config
151
+ . dangerous ( )
152
+ . with_custom_certificate_verifier ( Arc :: new ( DummyTlsVerifier { provider } ) )
153
+ . with_client_auth_cert ( user_auth. 0 , user_auth. 1 )
154
+ . map_err ( Error :: tls) ?
155
+ } else {
156
+ config
157
+ . dangerous ( )
158
+ . with_custom_certificate_verifier ( Arc :: new ( DummyTlsVerifier { provider } ) )
159
+ . with_no_client_auth ( )
160
+ }
133
161
} else {
134
- config
135
- . dangerous ( )
136
- . with_custom_certificate_verifier ( Arc :: new ( DummyTlsVerifier { provider } ) )
137
- . with_no_client_auth ( )
138
- }
139
- } else {
140
- let mut cert_store = import_root_certs ( ) ;
162
+ let mut cert_store = import_root_certs ( ) ;
141
163
142
- if let Some ( ca) = tls_config . root_cert_path {
143
- let data = ca. data ( ) . await ?;
164
+ if let Some ( ca) = root_cert {
165
+ let data = ca. data ( ) . await ?;
144
166
145
- for result in CertificateDer :: pem_slice_iter ( & data) {
146
- let Ok ( cert) = result else {
147
- return Err ( Error :: Tls ( format ! ( "Invalid certificate {ca}" ) . into ( ) ) ) ;
148
- } ;
167
+ for result in CertificateDer :: pem_slice_iter ( & data) {
168
+ let Ok ( cert) = result else {
169
+ return Err ( Error :: Tls ( format ! ( "Invalid certificate {ca}" ) . into ( ) ) ) ;
170
+ } ;
149
171
150
- cert_store. add ( cert) . map_err ( |err| Error :: Tls ( err. into ( ) ) ) ?;
172
+ cert_store. add ( cert) . map_err ( |err| Error :: Tls ( err. into ( ) ) ) ?;
173
+ }
151
174
}
152
- }
153
-
154
- if tls_config. accept_invalid_hostnames {
155
- let verifier = WebPkiServerVerifier :: builder ( Arc :: new ( cert_store) )
156
- . build ( )
157
- . map_err ( |err| Error :: Tls ( err. into ( ) ) ) ?;
158
175
159
- if let Some ( user_auth) = user_auth {
176
+ if * accept_invalid_hostnames {
177
+ let verifier = WebPkiServerVerifier :: builder ( Arc :: new ( cert_store) )
178
+ . build ( )
179
+ . map_err ( |err| Error :: Tls ( err. into ( ) ) ) ?;
180
+
181
+ if let Some ( user_auth) = user_auth {
182
+ config
183
+ . dangerous ( )
184
+ . with_custom_certificate_verifier ( Arc :: new ( NoHostnameTlsVerifier {
185
+ verifier,
186
+ } ) )
187
+ . with_client_auth_cert ( user_auth. 0 , user_auth. 1 )
188
+ . map_err ( Error :: tls) ?
189
+ } else {
190
+ config
191
+ . dangerous ( )
192
+ . with_custom_certificate_verifier ( Arc :: new ( NoHostnameTlsVerifier {
193
+ verifier,
194
+ } ) )
195
+ . with_no_client_auth ( )
196
+ }
197
+ } else if let Some ( user_auth) = user_auth {
160
198
config
161
- . dangerous ( )
162
- . with_custom_certificate_verifier ( Arc :: new ( NoHostnameTlsVerifier { verifier } ) )
199
+ . with_root_certificates ( cert_store)
163
200
. with_client_auth_cert ( user_auth. 0 , user_auth. 1 )
164
201
. map_err ( Error :: tls) ?
165
202
} else {
166
203
config
167
- . dangerous ( )
168
- . with_custom_certificate_verifier ( Arc :: new ( NoHostnameTlsVerifier { verifier } ) )
204
+ . with_root_certificates ( cert_store)
169
205
. with_no_client_auth ( )
170
206
}
171
- } else if let Some ( user_auth) = user_auth {
172
- config
173
- . with_root_certificates ( cert_store)
174
- . with_client_auth_cert ( user_auth. 0 , user_auth. 1 )
175
- . map_err ( Error :: tls) ?
176
- } else {
177
- config
178
- . with_root_certificates ( cert_store)
179
- . with_no_client_auth ( )
180
- }
181
- } ;
207
+ } ;
208
+
209
+ Ok ( ( config, hostname) )
210
+ }
211
+ }
182
212
183
- let host = ServerName :: try_from ( tls_config. hostname . to_owned ( ) ) . map_err ( Error :: tls) ?;
213
+ pub async fn handshake < S > ( socket : S , tls_config : TlsConfig < ' _ > ) -> Result < RustlsSocket < S > , Error >
214
+ where
215
+ S : Socket ,
216
+ {
217
+ let ( config, hostname) = tls_config. rustls_config ( ) . await ?;
218
+ let host = ServerName :: try_from ( hostname. to_owned ( ) ) . map_err ( Error :: tls) ?;
184
219
185
220
let mut socket = RustlsSocket {
186
221
inner : StdSocket :: new ( socket) ,
0 commit comments