Skip to content

Commit 2782913

Browse files
committed
Improve handling of hostnames
1 parent 5af89e4 commit 2782913

File tree

3 files changed

+76
-79
lines changed

3 files changed

+76
-79
lines changed

src/lib.rs

Lines changed: 50 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -218,10 +218,8 @@ where
218218
{
219219
inner: Proxy<L, R>,
220220
connector: TlsConnector,
221-
/// The address of the proxy server
222-
target: SocketAddr,
223-
/// The subject name of the proxy server
224-
target_name: ServerName<'static>,
221+
/// The host and port of the proxy server
222+
target: String,
225223
/// Certificate chain for client auth
226224
cert_chain: Option<Vec<CertificateDer<'static>>>,
227225
}
@@ -230,8 +228,7 @@ impl<L: AttestationPlatform, R: AttestationPlatform> ProxyClient<L, R> {
230228
pub async fn new(
231229
cert_and_key: Option<TlsCertAndKey>,
232230
address: impl ToSocketAddrs,
233-
server_address: SocketAddr,
234-
server_name: ServerName<'static>,
231+
server_name: String,
235232
local_attestation_platform: L,
236233
remote_attestation_platform: R,
237234
) -> Result<Self, ProxyError> {
@@ -257,7 +254,6 @@ impl<L: AttestationPlatform, R: AttestationPlatform> ProxyClient<L, R> {
257254
Self::new_with_tls_config(
258255
client_config.into(),
259256
address,
260-
server_address,
261257
server_name,
262258
local_attestation_platform,
263259
remote_attestation_platform,
@@ -272,8 +268,7 @@ impl<L: AttestationPlatform, R: AttestationPlatform> ProxyClient<L, R> {
272268
async fn new_with_tls_config(
273269
client_config: Arc<ClientConfig>,
274270
local: impl ToSocketAddrs,
275-
target: SocketAddr,
276-
target_name: ServerName<'static>,
271+
target_name: String,
277272
local_attestation_platform: L,
278273
remote_attestation_platform: R,
279274
cert_chain: Option<Vec<CertificateDer<'static>>>,
@@ -290,8 +285,7 @@ impl<L: AttestationPlatform, R: AttestationPlatform> ProxyClient<L, R> {
290285
Ok(Self {
291286
inner,
292287
connector,
293-
target,
294-
target_name,
288+
target: host_to_host_with_port(&target_name),
295289
cert_chain,
296290
})
297291
}
@@ -301,8 +295,7 @@ impl<L: AttestationPlatform, R: AttestationPlatform> ProxyClient<L, R> {
301295
let (inbound, _client_addr) = self.inner.listener.accept().await?;
302296

303297
let connector = self.connector.clone();
304-
let target_name = self.target_name.clone();
305-
let target = self.target;
298+
let target = self.target.clone();
306299
let local_attestation_platform = self.inner.local_attestation_platform.clone();
307300
let remote_attestation_platform = self.inner.remote_attestation_platform.clone();
308301
let cert_chain = self.cert_chain.clone();
@@ -312,7 +305,6 @@ impl<L: AttestationPlatform, R: AttestationPlatform> ProxyClient<L, R> {
312305
inbound,
313306
connector,
314307
target,
315-
target_name,
316308
cert_chain,
317309
local_attestation_platform,
318310
remote_attestation_platform,
@@ -335,14 +327,15 @@ impl<L: AttestationPlatform, R: AttestationPlatform> ProxyClient<L, R> {
335327
async fn handle_connection(
336328
inbound: TcpStream,
337329
connector: TlsConnector,
338-
target: SocketAddr,
339-
target_name: ServerName<'static>,
330+
target: String,
340331
cert_chain: Option<Vec<CertificateDer<'static>>>,
341332
local_attestation_platform: L,
342333
remote_attestation_platform: R,
343334
) -> Result<(), ProxyError> {
344-
let out = TcpStream::connect(target).await?;
345-
let mut tls_stream = connector.connect(target_name, out).await?;
335+
let out = TcpStream::connect(&target).await?;
336+
let mut tls_stream = connector
337+
.connect(server_name_from_host(&target)?, out)
338+
.await?;
346339

347340
let (_io, server_connection) = tls_stream.get_ref();
348341

@@ -394,16 +387,14 @@ impl<L: AttestationPlatform, R: AttestationPlatform> ProxyClient<L, R> {
394387

395388
/// Just get the attested remote certificate, with no client authentication
396389
pub async fn get_tls_cert<R: AttestationPlatform>(
397-
server_address: SocketAddr,
398-
server_name: ServerName<'static>,
390+
server_name: String,
399391
remote_attestation_platform: R,
400392
) -> Result<Vec<CertificateDer<'static>>, ProxyError> {
401393
let root_store = RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
402394
let client_config = ClientConfig::builder()
403395
.with_root_certificates(root_store)
404396
.with_no_client_auth();
405397
get_tls_cert_with_config(
406-
server_address,
407398
server_name,
408399
remote_attestation_platform,
409400
client_config.into(),
@@ -412,15 +403,16 @@ pub async fn get_tls_cert<R: AttestationPlatform>(
412403
}
413404

414405
async fn get_tls_cert_with_config<R: AttestationPlatform>(
415-
server_address: SocketAddr,
416-
server_name: ServerName<'static>,
406+
server_name: String,
417407
remote_attestation_platform: R,
418408
client_config: Arc<ClientConfig>,
419409
) -> Result<Vec<CertificateDer<'static>>, ProxyError> {
420410
let connector = TlsConnector::from(client_config);
421411

422-
let out = TcpStream::connect(server_address).await?;
423-
let mut tls_stream = connector.connect(server_name, out).await?;
412+
let out = TcpStream::connect(host_to_host_with_port(&server_name)).await?;
413+
let mut tls_stream = connector
414+
.connect(server_name_from_host(&server_name)?, out)
415+
.await?;
424416

425417
let (_io, server_connection) = tls_stream.get_ref();
426418

@@ -467,6 +459,8 @@ pub enum ProxyError {
467459
Attestation(#[from] AttestationError),
468460
#[error("Integer conversion: {0}")]
469461
IntConversion(#[from] TryFromIntError),
462+
#[error("Bad host name: {0}")]
463+
BadDnsName(#[from] tokio_rustls::rustls::pki_types::InvalidDnsNameError),
470464
}
471465

472466
/// Given a byte array, encode its length as a 4 byte big endian u32
@@ -475,6 +469,27 @@ fn length_prefix(input: &[u8]) -> [u8; 4] {
475469
len.to_be_bytes()
476470
}
477471

472+
fn host_to_host_with_port(host: &str) -> String {
473+
if host.contains(':') {
474+
host.to_string()
475+
} else {
476+
format!("{host}:443")
477+
}
478+
}
479+
480+
fn server_name_from_host(
481+
host: &str,
482+
) -> Result<ServerName<'static>, tokio_rustls::rustls::pki_types::InvalidDnsNameError> {
483+
// If host contains ':', try to split off the port.
484+
let host_part = host.rsplit_once(':').map(|(h, _)| h).unwrap_or(host);
485+
486+
// If the host is an IPv6 literal in brackets like "[::1]:443",
487+
// remove the brackets for SNI (SNI allows bare IPv6 too).
488+
let host_part = host_part.trim_matches(|c| c == '[' || c == ']');
489+
490+
ServerName::try_from(host_part.to_string())
491+
}
492+
478493
#[cfg(test)]
479494
mod tests {
480495
use super::*;
@@ -486,9 +501,8 @@ mod tests {
486501
#[tokio::test]
487502
async fn http_proxy() {
488503
let target_addr = example_http_service().await;
489-
let target_name = "name".to_string();
490504

491-
let (cert_chain, private_key) = generate_certificate_chain(target_name.clone());
505+
let (cert_chain, private_key) = generate_certificate_chain("127.0.0.1".parse().unwrap());
492506
let (server_config, client_config) = generate_tls_config(cert_chain.clone(), private_key);
493507

494508
let proxy_server = ProxyServer::new_with_tls_config(
@@ -510,9 +524,8 @@ mod tests {
510524

511525
let proxy_client = ProxyClient::new_with_tls_config(
512526
client_config,
513-
"127.0.0.1:0",
514-
proxy_addr,
515-
target_name.try_into().unwrap(),
527+
"127.0.0.1:0".to_string(),
528+
proxy_addr.to_string(),
516529
NoAttestation,
517530
MockAttestation,
518531
None,
@@ -539,12 +552,11 @@ mod tests {
539552
#[tokio::test]
540553
async fn http_proxy_mutual_attestation() {
541554
let target_addr = example_http_service().await;
542-
let target_name = "name".to_string();
543555

544556
let (server_cert_chain, server_private_key) =
545-
generate_certificate_chain(target_name.clone());
557+
generate_certificate_chain("127.0.0.1".parse().unwrap());
546558
let (client_cert_chain, client_private_key) =
547-
generate_certificate_chain(target_name.clone());
559+
generate_certificate_chain("127.0.0.1".parse().unwrap());
548560

549561
let (
550562
(_client_tls_server_config, client_tls_client_config),
@@ -576,8 +588,7 @@ mod tests {
576588
let proxy_client = ProxyClient::new_with_tls_config(
577589
client_tls_client_config,
578590
"127.0.0.1:0",
579-
proxy_addr,
580-
target_name.try_into().unwrap(),
591+
proxy_addr.to_string(),
581592
MockAttestation,
582593
MockAttestation,
583594
Some(client_cert_chain),
@@ -604,9 +615,8 @@ mod tests {
604615
#[tokio::test]
605616
async fn raw_tcp_proxy() {
606617
let target_addr = example_service().await;
607-
let target_name = "name".to_string();
608618

609-
let (cert_chain, private_key) = generate_certificate_chain(target_name.clone());
619+
let (cert_chain, private_key) = generate_certificate_chain("127.0.0.1".parse().unwrap());
610620
let (server_config, client_config) = generate_tls_config(cert_chain.clone(), private_key);
611621

612622
let local_attestation_platform = MockAttestation;
@@ -631,8 +641,7 @@ mod tests {
631641
let proxy_client = ProxyClient::new_with_tls_config(
632642
client_config,
633643
"127.0.0.1:0",
634-
proxy_server_addr,
635-
target_name.try_into().unwrap(),
644+
proxy_server_addr.to_string(),
636645
NoAttestation,
637646
MockAttestation,
638647
None,
@@ -657,9 +666,8 @@ mod tests {
657666
#[tokio::test]
658667
async fn test_get_tls_cert() {
659668
let target_addr = example_service().await;
660-
let target_name = "name".to_string();
661669

662-
let (cert_chain, private_key) = generate_certificate_chain(target_name.clone());
670+
let (cert_chain, private_key) = generate_certificate_chain("127.0.0.1".parse().unwrap());
663671
let (server_config, client_config) = generate_tls_config(cert_chain.clone(), private_key);
664672

665673
let local_attestation_platform = MockAttestation;
@@ -682,8 +690,7 @@ mod tests {
682690
});
683691

684692
let retrieved_chain = get_tls_cert_with_config(
685-
proxy_server_addr,
686-
target_name.try_into().unwrap(),
693+
proxy_server_addr.to_string(),
687694
MockAttestation,
688695
client_config,
689696
)

src/main.rs

Lines changed: 10 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,10 @@ enum CliCommand {
1919
/// Run a proxy client
2020
Client {
2121
/// Socket address to listen on
22-
#[arg(short, long)]
22+
#[arg(short, long, default_value = "0.0.0.0:0")]
2323
address: SocketAddr,
24-
/// The socket address of the proxy server
25-
#[arg(short, long)]
26-
server_address: SocketAddr,
27-
/// The domain name of the proxy server
28-
#[arg(long)]
29-
server_name: String,
24+
/// The hostname:port or ip:port of the proxy server (port defaults to 443)
25+
server: String,
3026
/// The path to a PEM encoded private key for client authentication
3127
#[arg(long)]
3228
private_key: Option<PathBuf>,
@@ -37,10 +33,9 @@ enum CliCommand {
3733
/// Run a proxy server
3834
Server {
3935
/// Socket address to listen on
40-
#[arg(short, long)]
36+
#[arg(short, long, default_value = "0.0.0.0:0")]
4137
address: SocketAddr,
4238
/// Socket address of the target service to forward traffic to
43-
#[arg(short, long)]
4439
target_address: SocketAddr,
4540
/// The path to a PEM encoded private key
4641
#[arg(long)]
@@ -55,12 +50,8 @@ enum CliCommand {
5550
},
5651
/// Retrieve the attested TLS certificate from a proxy server
5752
GetTlsCert {
58-
/// The socket address of the proxy server
59-
#[arg(short, long)]
60-
server_address: SocketAddr,
61-
/// The domain name of the proxy server
62-
#[arg(long)]
63-
server_name: String,
53+
/// The hostname:port or ip:port of the proxy server (port defaults to 443)
54+
server: String,
6455
},
6556
}
6657

@@ -71,8 +62,7 @@ async fn main() -> anyhow::Result<()> {
7162
match cli.command {
7263
CliCommand::Client {
7364
address,
74-
server_name,
75-
server_address,
65+
server,
7666
private_key,
7767
cert_chain,
7868
} => {
@@ -92,8 +82,7 @@ async fn main() -> anyhow::Result<()> {
9282
let client = ProxyClient::new(
9383
tls_cert_and_chain,
9484
address,
95-
server_address,
96-
server_name.try_into()?,
85+
server,
9786
NoAttestation,
9887
MockAttestation,
9988
)
@@ -132,12 +121,8 @@ async fn main() -> anyhow::Result<()> {
132121
}
133122
}
134123
}
135-
CliCommand::GetTlsCert {
136-
server_address,
137-
server_name,
138-
} => {
139-
let cert_chain =
140-
get_tls_cert(server_address, server_name.try_into()?, MockAttestation).await?;
124+
CliCommand::GetTlsCert { server } => {
125+
let cert_chain = get_tls_cert(server, MockAttestation).await?;
141126
println!("{}", certs_to_pem_string(&cert_chain)?);
142127
}
143128
}

src/test_helpers.rs

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1-
use rcgen::generate_simple_self_signed;
2-
use std::{net::SocketAddr, sync::Arc};
1+
use std::{
2+
net::{IpAddr, SocketAddr},
3+
sync::Arc,
4+
};
35
use tokio::io::AsyncWriteExt;
46
use tokio::net::TcpListener;
57
use tokio_rustls::rustls::{
@@ -10,16 +12,19 @@ use tokio_rustls::rustls::{
1012

1113
/// Helper to generate a self-signed certificate for testing
1214
pub fn generate_certificate_chain(
13-
name: String,
15+
ip: IpAddr,
1416
) -> (Vec<CertificateDer<'static>>, PrivateKeyDer<'static>) {
15-
let subject_alt_names = vec![name];
16-
let cert_key = generate_simple_self_signed(subject_alt_names)
17-
.expect("Failed to generate self-signed certificate");
18-
19-
let certs = vec![CertificateDer::from(cert_key.cert)];
20-
let key = PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(
21-
cert_key.signing_key.serialize_der(),
22-
));
17+
let mut params = rcgen::CertificateParams::new(vec![]).unwrap();
18+
params.subject_alt_names.push(rcgen::SanType::IpAddress(ip));
19+
params
20+
.distinguished_name
21+
.push(rcgen::DnType::CommonName, ip.to_string());
22+
23+
let keypair = rcgen::KeyPair::generate().unwrap();
24+
let cert = params.self_signed(&keypair).unwrap();
25+
26+
let certs = vec![CertificateDer::from(cert)];
27+
let key = PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(keypair.serialize_der()));
2328
(certs, key)
2429
}
2530

0 commit comments

Comments
 (0)