1+ use std:: { net:: SocketAddr , sync:: Arc } ;
12use thiserror:: Error ;
3+ use tokio:: net:: { TcpListener , ToSocketAddrs } ;
24use tokio_tungstenite:: { tungstenite:: protocol:: WebSocketConfig , WebSocketStream } ;
35
46use crate :: {
@@ -15,9 +17,24 @@ pub struct AttestedWsServer {
1517 pub inner : AttestedTlsServer ,
1618 /// Optional websocket configuration
1719 pub websocket_config : Option < WebSocketConfig > ,
20+ listener : Arc < TcpListener > ,
1821}
1922
2023impl AttestedWsServer {
24+ pub async fn new (
25+ addr : impl ToSocketAddrs ,
26+ inner : AttestedTlsServer ,
27+ websocket_config : Option < WebSocketConfig > ,
28+ ) -> Result < Self , AttestedWsError > {
29+ let listener = TcpListener :: bind ( addr) . await ?;
30+
31+ Ok ( Self {
32+ listener : listener. into ( ) ,
33+ inner,
34+ websocket_config,
35+ } )
36+ }
37+
2138 /// Accept a Websocket connection
2239 pub async fn accept (
2340 & self ,
@@ -29,21 +46,20 @@ impl AttestedWsServer {
2946 ) ,
3047 AttestedWsError ,
3148 > {
32- let ( stream, measurements, attestation_type) = self . inner . accept ( ) . await ?;
49+ let ( tcp_stream, _addr) = self . listener . accept ( ) . await ?;
50+
51+ let ( stream, measurements, attestation_type) =
52+ self . inner . handle_connection ( tcp_stream) . await ?;
3353 Ok ( (
3454 tokio_tungstenite:: accept_async_with_config ( stream, self . websocket_config ) . await ?,
3555 measurements,
3656 attestation_type,
3757 ) )
3858 }
39- }
4059
41- impl From < AttestedTlsServer > for AttestedWsServer {
42- fn from ( inner : AttestedTlsServer ) -> Self {
43- Self {
44- inner,
45- websocket_config : None ,
46- }
60+ /// Helper to get the socket address of the underlying TCP listener
61+ pub fn local_addr ( & self ) -> std:: io:: Result < SocketAddr > {
62+ self . listener . local_addr ( )
4763 }
4864}
4965
@@ -68,7 +84,7 @@ impl AttestedWsClient {
6884 ) ,
6985 AttestedWsError ,
7086 > {
71- let ( stream, measurements, attestation_type) = self . inner . connect ( server) . await ?;
87+ let ( stream, measurements, attestation_type) = self . inner . connect_tcp ( server) . await ?;
7288 let ( ws_connection, _response) = tokio_tungstenite:: client_async_with_config (
7389 format ! ( "wss://{server}" ) ,
7490 stream,
@@ -95,6 +111,8 @@ pub enum AttestedWsError {
95111 Rustls ( #[ from] AttestedTlsError ) ,
96112 #[ error( "Websockets: {0}" ) ]
97113 Tungstenite ( #[ from] tokio_tungstenite:: tungstenite:: Error ) ,
114+ #[ error( "IO: {0}" ) ]
115+ Io ( #[ from] std:: io:: Error ) ,
98116}
99117
100118#[ cfg( test) ]
@@ -116,16 +134,17 @@ mod tests {
116134 let server = AttestedTlsServer :: new_with_tls_config (
117135 cert_chain,
118136 server_config,
119- "127.0.0.1:0" ,
120137 AttestationGenerator :: new_not_dummy ( AttestationType :: DcapTdx ) . unwrap ( ) ,
121138 AttestationVerifier :: expect_none ( ) ,
122139 )
123140 . await
124141 . unwrap ( ) ;
125142
126- let server_addr = server. local_addr ( ) . unwrap ( ) ;
143+ let ws_server = AttestedWsServer :: new ( "127.0.0.1:0" , server, None )
144+ . await
145+ . unwrap ( ) ;
127146
128- let ws_server : AttestedWsServer = server . into ( ) ;
147+ let server_addr = ws_server . local_addr ( ) . unwrap ( ) ;
129148
130149 tokio:: spawn ( async move {
131150 let ( mut ws_connection, _measurements, _attestation_type) =
0 commit comments