Skip to content

Commit 0f388fc

Browse files
authored
Update websocket error handling (#148)
* Update websocket error handling * Format
1 parent 056c890 commit 0f388fc

File tree

1 file changed

+31
-2
lines changed

1 file changed

+31
-2
lines changed

src/websocket.rs

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ use rust_decimal::Decimal;
1414
use serde::{Deserialize, Deserializer, Serialize, Serializer};
1515
use serde_json::json;
1616
use tokio::{
17+
io::AsyncWriteExt,
1718
net::{TcpListener, TcpStream},
1819
sync::Mutex,
1920
task::JoinHandle,
@@ -50,13 +51,41 @@ pub async fn start_ws_server(
5051
}
5152

5253
async fn accept_connection(
53-
stream: TcpStream,
54+
mut stream: TcpStream,
5455
ws_client: Arc<PubsubClient>,
5556
wallet: Arc<Wallet>,
5657
program_data: &'static ProgramData,
5758
) {
5859
let addr = stream.peer_addr().expect("peer address");
59-
let ws_stream = accept_async(stream).await.expect("Ws handshake");
60+
61+
// Check for WebSocket upgrade header before accept_async consumes the stream
62+
let mut buf = [0u8; 1024];
63+
let n = stream.peek(&mut buf).await.unwrap_or(0);
64+
if !buf[..n]
65+
.windows(7)
66+
.any(|w| w.eq_ignore_ascii_case(b"upgrade"))
67+
{
68+
warn!(target: LOG_TARGET, "non-WebSocket request from {}, rejecting", addr);
69+
let _ = stream
70+
.write_all(
71+
b"HTTP/1.1 426 Upgrade Required\r\n\
72+
Content-Type: text/plain\r\n\
73+
Upgrade: websocket\r\n\
74+
Connection: Upgrade\r\n\
75+
\r\n\
76+
This is a WebSocket endpoint. Use a WebSocket client to connect.\n",
77+
)
78+
.await;
79+
return;
80+
}
81+
82+
let ws_stream = match accept_async(stream).await {
83+
Ok(ws) => ws,
84+
Err(err) => {
85+
warn!(target: LOG_TARGET, "Ws handshake failed from {}: {}", addr, err);
86+
return;
87+
}
88+
};
6089
info!(target: LOG_TARGET, "accepted Ws connection: {}", addr);
6190

6291
let (mut ws_out, mut ws_in) = ws_stream.split();

0 commit comments

Comments
 (0)