@@ -10,6 +10,7 @@ import (
1010 "net"
1111 "net/http"
1212 "os"
13+ "strings"
1314 "sync"
1415 "sync/atomic"
1516 "time"
@@ -295,7 +296,100 @@ func (s *SshClient) httpTunnel(src, dst net.Conn) {
295296 return
296297 }
297298
298- // read and replace response body
299+ // Handle WebSocket upgrades and SSE streams with TCP tunneling
300+ if response .StatusCode == http .StatusSwitchingProtocols {
301+ // WebSocket upgrade - write response headers and switch to TCP tunneling
302+ err = response .Write (srcWriter )
303+ if err != nil {
304+ if s .config .Debug {
305+ s .logDebug ("Failed to write WebSocket upgrade response" , err )
306+ }
307+ return
308+ }
309+ srcWriter .Flush ()
310+
311+ // Drain any bytes already buffered post-handshake to avoid loss when switching to raw TCP
312+ if n := dstReader .Buffered (); n > 0 {
313+ buf := make ([]byte , n )
314+ if _ , err := io .ReadFull (dstReader , buf ); err == nil {
315+ if _ , err := srcWriter .Write (buf ); err != nil {
316+ if s .config .Debug {
317+ s .logDebug ("Failed to flush buffered server bytes on WS upgrade" , err )
318+ }
319+ return
320+ }
321+ srcWriter .Flush ()
322+ }
323+ }
324+
325+ if n := srcReader .Buffered (); n > 0 {
326+ buf := make ([]byte , n )
327+ if _ , err := io .ReadFull (srcReader , buf ); err == nil {
328+ if _ , err := dstWriter .Write (buf ); err != nil {
329+ if s .config .Debug {
330+ s .logDebug ("Failed to flush buffered client bytes on WS upgrade" , err )
331+ }
332+ return
333+ }
334+ dstWriter .Flush ()
335+ }
336+ }
337+
338+ s .tcpTunnel (src , dst )
339+ return
340+ }
341+
342+ // Check for SSE (Server-Sent Events) streams
343+ contentType := response .Header .Get ("Content-Type" )
344+ if strings .Contains (contentType , "text/event-stream" ) {
345+ // Ensure SSE response body is closed when streaming finishes or on error
346+ defer response .Body .Close ()
347+
348+ // SSE stream - copy the response body in real-time without buffering
349+ // Write status line and headers first
350+ fmt .Fprintf (srcWriter , "%s %s\r \n " , response .Proto , response .Status )
351+
352+ // Write headers, excluding Content-Length and Transfer-Encoding
353+ // as we'll be streaming the body directly
354+ for key , values := range response .Header {
355+ if key == "Content-Length" || key == "Transfer-Encoding" {
356+ continue
357+ }
358+ for _ , value := range values {
359+ fmt .Fprintf (srcWriter , "%s: %s\r \n " , key , value )
360+ }
361+ }
362+
363+ // Empty line to end headers
364+ fmt .Fprintf (srcWriter , "\r \n " )
365+ srcWriter .Flush ()
366+
367+ // Stream the body with immediate flushing for real-time delivery
368+ buf := make ([]byte , 32 * 1024 ) // 32KB buffer
369+ for {
370+ n , err := response .Body .Read (buf )
371+ if n > 0 {
372+ _ , writeErr := srcWriter .Write (buf [:n ])
373+ if writeErr != nil {
374+ if s .config .Debug {
375+ s .logDebug ("Failed to write SSE data" , writeErr )
376+ }
377+ return
378+ }
379+ // Flush immediately to ensure real-time streaming
380+ srcWriter .Flush ()
381+ }
382+ if err != nil {
383+ if err != io .EOF && s .config .Debug {
384+ s .logDebug ("SSE stream ended" , err )
385+ }
386+ break
387+ }
388+ }
389+ return
390+ }
391+
392+ // read and replace response body for regular HTTP responses
299393 responseBody , err := io .ReadAll (response .Body )
300394 if err != nil {
301395 if s .config .Debug {
@@ -315,12 +409,6 @@ func (s *SshClient) httpTunnel(src, dst net.Conn) {
315409 }
316410 srcWriter .Flush ()
317411
318- if response .StatusCode == http .StatusSwitchingProtocols {
319- // handle websocket
320- s .tcpTunnel (src , dst )
321- return
322- }
323-
324412 s .logHttpRequest (ulid .Make ().String (), request , requestBody , response , responseBody )
325413}
326414
0 commit comments