Skip to content

Commit 54be39f

Browse files
committed
Update following merging main
1 parent 8aec1a1 commit 54be39f

File tree

2 files changed

+38
-16
lines changed

2 files changed

+38
-16
lines changed

src/attested_tls.rs

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -505,6 +505,7 @@ fn server_name_from_host(
505505
mod tests {
506506
use super::*;
507507
use crate::test_helpers::{generate_certificate_chain, generate_tls_config};
508+
use tokio::net::TcpListener;
508509

509510
#[tokio::test]
510511
async fn server_attestation() {
@@ -514,17 +515,19 @@ mod tests {
514515
let server = AttestedTlsServer::new_with_tls_config(
515516
cert_chain,
516517
server_config,
517-
"127.0.0.1:0",
518518
AttestationGenerator::new_not_dummy(AttestationType::DcapTdx).unwrap(),
519519
AttestationVerifier::expect_none(),
520520
)
521521
.await
522522
.unwrap();
523523

524-
let server_addr = server.local_addr().unwrap();
524+
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
525+
let server_addr = listener.local_addr().unwrap();
525526

526527
tokio::spawn(async move {
527-
let (_stream, _measurements, _attestation_type) = server.accept().await.unwrap();
528+
let (tcp_stream, _) = listener.accept().await.unwrap();
529+
let (_stream, _measurements, _attestation_type) =
530+
server.handle_connection(tcp_stream).await.unwrap();
528531
});
529532

530533
let client = AttestedTlsClient::new_with_tls_config(
@@ -537,6 +540,6 @@ mod tests {
537540
.unwrap();
538541

539542
let (_stream, _measurements, _attestation_type) =
540-
client.connect(&server_addr.to_string()).await.unwrap();
543+
client.connect_tcp(&server_addr.to_string()).await.unwrap();
541544
}
542545
}

src/websockets.rs

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1+
use std::{net::SocketAddr, sync::Arc};
12
use thiserror::Error;
3+
use tokio::net::{TcpListener, ToSocketAddrs};
24
use tokio_tungstenite::{tungstenite::protocol::WebSocketConfig, WebSocketStream};
35

46
use crate::{
@@ -15,9 +17,24 @@ pub struct AttestedWsServer {
1517
pub inner: AttestedTlsServer,
1618
/// Optional websocket configuration
1719
pub websocket_config: Option<WebSocketConfig>,
20+
listener: Arc<TcpListener>,
1821
}
1922

2023
impl AttestedWsServer {
24+
pub async fn new(
25+
addr: impl ToSocketAddrs,
26+
inner: AttestedTlsServer,
27+
websocket_config: Option<WebSocketConfig>,
28+
) -> Result<Self, AttestedWsError> {
29+
let listener = TcpListener::bind(addr).await?;
30+
31+
Ok(Self {
32+
listener: listener.into(),
33+
inner,
34+
websocket_config,
35+
})
36+
}
37+
2138
/// Accept a Websocket connection
2239
pub async fn accept(
2340
&self,
@@ -29,21 +46,20 @@ impl AttestedWsServer {
2946
),
3047
AttestedWsError,
3148
> {
32-
let (stream, measurements, attestation_type) = self.inner.accept().await?;
49+
let (tcp_stream, _addr) = self.listener.accept().await?;
50+
51+
let (stream, measurements, attestation_type) =
52+
self.inner.handle_connection(tcp_stream).await?;
3353
Ok((
3454
tokio_tungstenite::accept_async_with_config(stream, self.websocket_config).await?,
3555
measurements,
3656
attestation_type,
3757
))
3858
}
39-
}
4059

41-
impl From<AttestedTlsServer> for AttestedWsServer {
42-
fn from(inner: AttestedTlsServer) -> Self {
43-
Self {
44-
inner,
45-
websocket_config: None,
46-
}
60+
/// Helper to get the socket address of the underlying TCP listener
61+
pub fn local_addr(&self) -> std::io::Result<SocketAddr> {
62+
self.listener.local_addr()
4763
}
4864
}
4965

@@ -68,7 +84,7 @@ impl AttestedWsClient {
6884
),
6985
AttestedWsError,
7086
> {
71-
let (stream, measurements, attestation_type) = self.inner.connect(server).await?;
87+
let (stream, measurements, attestation_type) = self.inner.connect_tcp(server).await?;
7288
let (ws_connection, _response) = tokio_tungstenite::client_async_with_config(
7389
format!("wss://{server}"),
7490
stream,
@@ -95,6 +111,8 @@ pub enum AttestedWsError {
95111
Rustls(#[from] AttestedTlsError),
96112
#[error("Websockets: {0}")]
97113
Tungstenite(#[from] tokio_tungstenite::tungstenite::Error),
114+
#[error("IO: {0}")]
115+
Io(#[from] std::io::Error),
98116
}
99117

100118
#[cfg(test)]
@@ -116,16 +134,17 @@ mod tests {
116134
let server = AttestedTlsServer::new_with_tls_config(
117135
cert_chain,
118136
server_config,
119-
"127.0.0.1:0",
120137
AttestationGenerator::new_not_dummy(AttestationType::DcapTdx).unwrap(),
121138
AttestationVerifier::expect_none(),
122139
)
123140
.await
124141
.unwrap();
125142

126-
let server_addr = server.local_addr().unwrap();
143+
let ws_server = AttestedWsServer::new("127.0.0.1:0", server, None)
144+
.await
145+
.unwrap();
127146

128-
let ws_server: AttestedWsServer = server.into();
147+
let server_addr = ws_server.local_addr().unwrap();
129148

130149
tokio::spawn(async move {
131150
let (mut ws_connection, _measurements, _attestation_type) =

0 commit comments

Comments
 (0)