11use thiserror:: Error ;
2- use tokio_tungstenite:: WebSocketStream ;
2+ use tokio_tungstenite:: { tungstenite :: protocol :: WebSocketConfig , WebSocketStream } ;
33
44use crate :: {
55 attestation:: { measurements:: MultiMeasurements , AttestationType } ,
@@ -9,12 +9,16 @@ use crate::{
99/// Websocket message type re-exported for convenience
1010pub use tokio_tungstenite:: tungstenite:: protocol:: Message ;
1111
12- // TODO allow setting ws config
12+ /// An attested Websocket server
1313pub struct AttestedWsServer {
14- inner : AttestedTlsServer ,
14+ /// The underlying attested TLS server
15+ pub inner : AttestedTlsServer ,
16+ /// Optional websocket configuration
17+ pub websocket_config : Option < WebSocketConfig > ,
1518}
1619
1720impl AttestedWsServer {
21+ /// Accept a Websocket connection
1822 pub async fn accept (
1923 & self ,
2024 ) -> Result <
@@ -27,18 +31,32 @@ impl AttestedWsServer {
2731 > {
2832 let ( stream, measurements, attestation_type) = self . inner . accept ( ) . await ?;
2933 Ok ( (
30- tokio_tungstenite:: accept_async ( stream) . await ?,
34+ tokio_tungstenite:: accept_async_with_config ( stream, self . websocket_config ) . await ?,
3135 measurements,
3236 attestation_type,
3337 ) )
3438 }
3539}
3640
41+ impl From < AttestedTlsServer > for AttestedWsServer {
42+ fn from ( inner : AttestedTlsServer ) -> Self {
43+ Self {
44+ inner,
45+ websocket_config : None ,
46+ }
47+ }
48+ }
49+
50+ /// An attested Websocket client
3751pub struct AttestedWsClient {
38- inner : AttestedTlsClient ,
52+ /// The underlying attested TLS client
53+ pub inner : AttestedTlsClient ,
54+ /// Optional websocket configuration
55+ pub websocket_config : Option < WebSocketConfig > ,
3956}
4057
4158impl AttestedWsClient {
59+ /// Make a Websocket connection
4260 pub async fn connect (
4361 & self ,
4462 server : & str ,
@@ -51,13 +69,26 @@ impl AttestedWsClient {
5169 AttestedWsError ,
5270 > {
5371 let ( stream, measurements, attestation_type) = self . inner . connect ( server) . await ?;
54- let ( ws_connection, _response) =
55- tokio_tungstenite:: client_async ( format ! ( "wss://{server}" ) , stream) . await ?;
72+ let ( ws_connection, _response) = tokio_tungstenite:: client_async_with_config (
73+ format ! ( "wss://{server}" ) ,
74+ stream,
75+ self . websocket_config ,
76+ )
77+ . await ?;
5678
5779 Ok ( ( ws_connection, measurements, attestation_type) )
5880 }
5981}
6082
83+ impl From < AttestedTlsClient > for AttestedWsClient {
84+ fn from ( inner : AttestedTlsClient ) -> Self {
85+ Self {
86+ inner,
87+ websocket_config : None ,
88+ }
89+ }
90+ }
91+
6192#[ derive( Error , Debug ) ]
6293pub enum AttestedWsError {
6394 #[ error( "Attested TLS: {0}" ) ]
@@ -94,7 +125,7 @@ mod tests {
94125
95126 let server_addr = server. local_addr ( ) . unwrap ( ) ;
96127
97- let ws_server = AttestedWsServer { inner : server } ;
128+ let ws_server: AttestedWsServer = server. into ( ) ;
98129
99130 tokio:: spawn ( async move {
100131 let ( mut ws_connection, _measurements, _attestation_type) =
@@ -115,7 +146,7 @@ mod tests {
115146 . await
116147 . unwrap ( ) ;
117148
118- let ws_client = AttestedWsClient { inner : client } ;
149+ let ws_client: AttestedWsClient = client. into ( ) ;
119150
120151 let ( mut ws_connection, _measurements, _attestation_type) =
121152 ws_client. connect ( & server_addr. to_string ( ) ) . await . unwrap ( ) ;
0 commit comments