@@ -14,6 +14,7 @@ use rust_decimal::Decimal;
1414use serde:: { Deserialize , Deserializer , Serialize , Serializer } ;
1515use serde_json:: json;
1616use tokio:: {
17+ io:: AsyncWriteExt ,
1718 net:: { TcpListener , TcpStream } ,
1819 sync:: Mutex ,
1920 task:: JoinHandle ,
@@ -50,13 +51,41 @@ pub async fn start_ws_server(
5051}
5152
5253async fn accept_connection (
53- stream : TcpStream ,
54+ mut stream : TcpStream ,
5455 ws_client : Arc < PubsubClient > ,
5556 wallet : Arc < Wallet > ,
5657 program_data : & ' static ProgramData ,
5758) {
5859 let addr = stream. peer_addr ( ) . expect ( "peer address" ) ;
59- let ws_stream = accept_async ( stream) . await . expect ( "Ws handshake" ) ;
60+
61+ // Check for WebSocket upgrade header before accept_async consumes the stream
62+ let mut buf = [ 0u8 ; 1024 ] ;
63+ let n = stream. peek ( & mut buf) . await . unwrap_or ( 0 ) ;
64+ if !buf[ ..n]
65+ . windows ( 7 )
66+ . any ( |w| w. eq_ignore_ascii_case ( b"upgrade" ) )
67+ {
68+ warn ! ( target: LOG_TARGET , "non-WebSocket request from {}, rejecting" , addr) ;
69+ let _ = stream
70+ . write_all (
71+ b"HTTP/1.1 426 Upgrade Required\r \n \
72+ Content-Type: text/plain\r \n \
73+ Upgrade: websocket\r \n \
74+ Connection: Upgrade\r \n \
75+ \r \n \
76+ This is a WebSocket endpoint. Use a WebSocket client to connect.\n ",
77+ )
78+ . await ;
79+ return ;
80+ }
81+
82+ let ws_stream = match accept_async ( stream) . await {
83+ Ok ( ws) => ws,
84+ Err ( err) => {
85+ warn ! ( target: LOG_TARGET , "Ws handshake failed from {}: {}" , addr, err) ;
86+ return ;
87+ }
88+ } ;
6089 info ! ( target: LOG_TARGET , "accepted Ws connection: {}" , addr) ;
6190
6291 let ( mut ws_out, mut ws_in) = ws_stream. split ( ) ;
0 commit comments