Skip to content

Commit 9d2fabf

Browse files
committed
Tidy, allow config to be passed in
1 parent a9c7f4a commit 9d2fabf

File tree

1 file changed

+40
-9
lines changed

1 file changed

+40
-9
lines changed

src/websockets.rs

Lines changed: 40 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use thiserror::Error;
2-
use tokio_tungstenite::WebSocketStream;
2+
use tokio_tungstenite::{tungstenite::protocol::WebSocketConfig, WebSocketStream};
33

44
use crate::{
55
attestation::{measurements::MultiMeasurements, AttestationType},
@@ -9,12 +9,16 @@ use crate::{
99
/// Websocket message type re-exported for convenience
1010
pub use tokio_tungstenite::tungstenite::protocol::Message;
1111

12-
// TODO allow setting ws config
12+
/// An attested Websocket server
1313
pub struct AttestedWsServer {
14-
inner: AttestedTlsServer,
14+
/// The underlying attested TLS server
15+
pub inner: AttestedTlsServer,
16+
/// Optional websocket configuration
17+
pub websocket_config: Option<WebSocketConfig>,
1518
}
1619

1720
impl AttestedWsServer {
21+
/// Accept a Websocket connection
1822
pub async fn accept(
1923
&self,
2024
) -> Result<
@@ -27,18 +31,32 @@ impl AttestedWsServer {
2731
> {
2832
let (stream, measurements, attestation_type) = self.inner.accept().await?;
2933
Ok((
30-
tokio_tungstenite::accept_async(stream).await?,
34+
tokio_tungstenite::accept_async_with_config(stream, self.websocket_config).await?,
3135
measurements,
3236
attestation_type,
3337
))
3438
}
3539
}
3640

41+
impl From<AttestedTlsServer> for AttestedWsServer {
42+
fn from(inner: AttestedTlsServer) -> Self {
43+
Self {
44+
inner,
45+
websocket_config: None,
46+
}
47+
}
48+
}
49+
50+
/// An attested Websocket client
3751
pub struct AttestedWsClient {
38-
inner: AttestedTlsClient,
52+
/// The underlying attested TLS client
53+
pub inner: AttestedTlsClient,
54+
/// Optional websocket configuration
55+
pub websocket_config: Option<WebSocketConfig>,
3956
}
4057

4158
impl AttestedWsClient {
59+
/// Make a Websocket connection
4260
pub async fn connect(
4361
&self,
4462
server: &str,
@@ -51,13 +69,26 @@ impl AttestedWsClient {
5169
AttestedWsError,
5270
> {
5371
let (stream, measurements, attestation_type) = self.inner.connect(server).await?;
54-
let (ws_connection, _response) =
55-
tokio_tungstenite::client_async(format!("wss://{server}"), stream).await?;
72+
let (ws_connection, _response) = tokio_tungstenite::client_async_with_config(
73+
format!("wss://{server}"),
74+
stream,
75+
self.websocket_config,
76+
)
77+
.await?;
5678

5779
Ok((ws_connection, measurements, attestation_type))
5880
}
5981
}
6082

83+
impl From<AttestedTlsClient> for AttestedWsClient {
84+
fn from(inner: AttestedTlsClient) -> Self {
85+
Self {
86+
inner,
87+
websocket_config: None,
88+
}
89+
}
90+
}
91+
6192
#[derive(Error, Debug)]
6293
pub enum AttestedWsError {
6394
#[error("Attested TLS: {0}")]
@@ -94,7 +125,7 @@ mod tests {
94125

95126
let server_addr = server.local_addr().unwrap();
96127

97-
let ws_server = AttestedWsServer { inner: server };
128+
let ws_server: AttestedWsServer = server.into();
98129

99130
tokio::spawn(async move {
100131
let (mut ws_connection, _measurements, _attestation_type) =
@@ -115,7 +146,7 @@ mod tests {
115146
.await
116147
.unwrap();
117148

118-
let ws_client = AttestedWsClient { inner: client };
149+
let ws_client: AttestedWsClient = client.into();
119150

120151
let (mut ws_connection, _measurements, _attestation_type) =
121152
ws_client.connect(&server_addr.to_string()).await.unwrap();

0 commit comments

Comments
 (0)