@@ -146,6 +146,10 @@ impl HttpServer {
146146 if let Err ( e) = send_res {
147147 error!( "Websocket message send error: {:?}" , e)
148148 }
149+ if res. should_close_connection( ) {
150+ log:: warn!( "Websocket connection closed" ) ;
151+ break ;
152+ }
149153 }
150154 Some ( msg) = web_socket. next( ) => {
151155 match msg {
@@ -284,12 +288,21 @@ impl HttpServer {
284288 "Error processing HTTP command: {}\n " ,
285289 e. display_with_backtrace( )
286290 ) ;
291+ let command = if e. is_wrong_connection ( ) {
292+ HttpCommand :: CloseConnection {
293+ error : e. to_string ( ) ,
294+ }
295+
296+ } else {
297+ HttpCommand :: Error {
298+ error : e. to_string ( ) ,
299+ }
300+ } ;
301+
287302 HttpMessage {
288303 message_id,
289304 connection_id,
290- command : HttpCommand :: Error {
291- error : e. to_string ( ) ,
292- } ,
305+ command,
293306 }
294307 }
295308 } ) ;
@@ -576,6 +589,9 @@ pub enum HttpCommand {
576589 ResultSet {
577590 data_frame : Arc < DataFrame > ,
578591 } ,
592+ CloseConnection {
593+ error : String ,
594+ } ,
579595 Error {
580596 error : String ,
581597 } ,
@@ -589,7 +605,9 @@ impl HttpMessage {
589605 command_type : match self . command {
590606 HttpCommand :: Query { .. } => crate :: codegen:: HttpCommand :: HttpQuery ,
591607 HttpCommand :: ResultSet { .. } => crate :: codegen:: HttpCommand :: HttpResultSet ,
592- HttpCommand :: Error { .. } => crate :: codegen:: HttpCommand :: HttpError ,
608+ HttpCommand :: CloseConnection { .. } | HttpCommand :: Error { .. } => {
609+ crate :: codegen:: HttpCommand :: HttpError
610+ }
593611 } ,
594612 command : match & self . command {
595613 HttpCommand :: Query {
@@ -614,7 +632,7 @@ impl HttpMessage {
614632 . as_union_value ( ) ,
615633 )
616634 }
617- HttpCommand :: Error { error } => {
635+ HttpCommand :: Error { error } | HttpCommand :: CloseConnection { error } => {
618636 let error_offset = builder. create_string ( & error) ;
619637 Some (
620638 HttpError :: create (
@@ -653,6 +671,10 @@ impl HttpMessage {
653671 builder. finished_data ( ) . to_vec ( ) // TODO copy
654672 }
655673
674+ pub fn should_close_connection ( & self ) -> bool {
675+ matches ! ( self . command, HttpCommand :: CloseConnection { .. } )
676+ }
677+
656678 fn build_columns < ' a : ' ma , ' ma > (
657679 builder : & ' ma mut FlatBufferBuilder < ' a > ,
658680 columns : & Vec < Column > ,
@@ -981,14 +1003,20 @@ mod tests {
9811003 async fn exec_query_with_context (
9821004 & self ,
9831005 _context : SqlQueryContext ,
984- _query : & str ,
1006+ query : & str ,
9851007 ) -> Result < Arc < DataFrame > , CubeError > {
9861008 tokio:: time:: sleep ( Duration :: from_secs ( 2 ) ) . await ;
9871009 let counter = self . message_counter . fetch_add ( 1 , Ordering :: Relaxed ) ;
988- Ok ( Arc :: new ( DataFrame :: new (
989- vec ! [ Column :: new( "foo" . to_string( ) , ColumnType :: String , 0 ) ] ,
990- vec ! [ Row :: new( vec![ TableValue :: String ( format!( "{}" , counter) ) ] ) ] ,
991- ) ) )
1010+ if query == "close_connection" {
1011+ Err ( CubeError :: wrong_connection ( "wrong connection" . to_string ( ) ) )
1012+ } else if query == "error" {
1013+ Err ( CubeError :: internal ( "error" . to_string ( ) ) )
1014+ } else {
1015+ Ok ( Arc :: new ( DataFrame :: new (
1016+ vec ! [ Column :: new( "foo" . to_string( ) , ColumnType :: String , 0 ) ] ,
1017+ vec ! [ Row :: new( vec![ TableValue :: String ( format!( "{}" , counter) ) ] ) ] ,
1018+ ) ) )
1019+ }
9921020 }
9931021
9941022 async fn plan_query ( & self , _query : & str ) -> Result < QueryPlans , CubeError > {
@@ -1046,19 +1074,25 @@ mod tests {
10461074
10471075 tokio:: time:: sleep ( Duration :: from_secs ( 1 ) ) . await ;
10481076
1049- async fn connect_and_send (
1050- message_id : u32 ,
1051- connection_id : Option < String > ,
1052- ) -> WebSocketStream < MaybeTlsStream < TcpStream > > {
1053- let ( mut socket, _) = connect_async ( Url :: parse ( "ws://127.0.0.1:53031/ws" ) . unwrap ( ) )
1077+ async fn connect ( ) -> WebSocketStream < MaybeTlsStream < TcpStream > > {
1078+ let ( socket, _) = connect_async ( Url :: parse ( "ws://127.0.0.1:53031/ws" ) . unwrap ( ) )
10541079 . await
10551080 . unwrap ( ) ;
10561081 socket
1082+ }
1083+
1084+ async fn send_query (
1085+ socket : & mut WebSocketStream < MaybeTlsStream < TcpStream > > ,
1086+ message_id : u32 ,
1087+ connection_id : Option < String > ,
1088+ query : & str ,
1089+ ) {
1090+ socket
10571091 . send ( Message :: binary (
10581092 HttpMessage {
10591093 message_id,
10601094 command : HttpCommand :: Query {
1061- query : "foo" . to_string ( ) ,
1095+ query : query . to_string ( ) ,
10621096 inline_tables : vec ! [ ] ,
10631097 trace_obj : None ,
10641098 } ,
@@ -1068,9 +1102,25 @@ mod tests {
10681102 ) )
10691103 . await
10701104 . unwrap ( ) ;
1105+ }
1106+
1107+ async fn connect_and_send_query (
1108+ message_id : u32 ,
1109+ connection_id : Option < String > ,
1110+ query : & str ,
1111+ ) -> WebSocketStream < MaybeTlsStream < TcpStream > > {
1112+ let mut socket = connect ( ) . await ;
1113+ send_query ( & mut socket, message_id, connection_id, query) . await ;
10711114 socket
10721115 }
10731116
1117+ async fn connect_and_send (
1118+ message_id : u32 ,
1119+ connection_id : Option < String > ,
1120+ ) -> WebSocketStream < MaybeTlsStream < TcpStream > > {
1121+ connect_and_send_query ( message_id, connection_id, "foo" ) . await
1122+ }
1123+
10741124 async fn assert_message (
10751125 socket : & mut WebSocketStream < MaybeTlsStream < TcpStream > > ,
10761126 counter : & str ,
@@ -1150,6 +1200,25 @@ mod tests {
11501200 } ,
11511201 ) ;
11521202
1203+ tokio:: time:: sleep ( Duration :: from_millis ( 2500 ) ) . await ;
1204+ let mut socket = connect_and_send ( 3 , Some ( "foo" . to_string ( ) ) ) . await ;
1205+ assert_message ( & mut socket, "6" ) . await ;
1206+
1207+ let mut socket2 = connect_and_send ( 3 , Some ( "foo2" . to_string ( ) ) ) . await ;
1208+ assert_message ( & mut socket2, "7" ) . await ;
1209+
1210+ send_query ( & mut socket, 3 , Some ( "foo" . to_string ( ) ) , "close_connection" ) . await ;
1211+ socket. next ( ) . await . unwrap ( ) . unwrap ( ) ;
1212+
1213+ send_query ( & mut socket2, 3 , Some ( "foo" . to_string ( ) ) , "error" ) . await ;
1214+ socket2. next ( ) . await . unwrap ( ) . unwrap ( ) ;
1215+
1216+ send_query ( & mut socket, 3 , Some ( "foo" . to_string ( ) ) , "foo" ) . await ;
1217+ assert ! ( socket. next( ) . await . unwrap( ) . is_err( ) ) ;
1218+
1219+ let mut socket2 = connect_and_send ( 3 , Some ( "foo2" . to_string ( ) ) ) . await ;
1220+ assert_message ( & mut socket2, "10" ) . await ;
1221+
11531222 http_server. stop_processing ( ) . await ;
11541223 }
11551224}
0 commit comments