11package cli
22
33import (
4+ "bufio"
45 "context"
56 "embed"
67 "encoding/json"
8+ "fmt"
79 "io/fs"
810 "net"
911 "net/http"
@@ -23,6 +25,33 @@ import (
2325//go:embed static/*
2426var static embed.FS
2527
28+ type responseRecorder struct {
29+ http.ResponseWriter
30+ headerWritten bool
31+ logger slog.Logger
32+ }
33+
34+ // Implement Hijacker interface for WebSocket support
35+ func (r * responseRecorder ) Hijack () (net.Conn , * bufio.ReadWriter , error ) {
36+ if hijacker , ok := r .ResponseWriter .(http.Hijacker ); ok {
37+ return hijacker .Hijack ()
38+ }
39+ return nil , nil , fmt .Errorf ("responseRecorder does not implement http.Hijacker" )
40+ }
41+
42+ // Wrap your handler
43+ func debugMiddleware (logger slog.Logger ) func (http.Handler ) http.Handler {
44+ return func (next http.Handler ) http.Handler {
45+ return http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
46+ recorder := & responseRecorder {
47+ ResponseWriter : w ,
48+ logger : logger ,
49+ }
50+ next .ServeHTTP (recorder , r )
51+ })
52+ }
53+ }
54+
2655func (r * RootCmd ) WebsocketServer () * serpent.Command {
2756 var (
2857 address string
@@ -48,6 +77,9 @@ func (r *RootCmd) WebsocketServer() *serpent.Command {
4877 logger := slog .Make (sloghuman .Sink (i .Stderr )).Leveled (slog .LevelDebug )
4978
5079 mux := chi .NewMux ()
80+
81+ mux .Use (debugMiddleware (logger ))
82+
5183 mux .HandleFunc ("/directories" , func (rw http.ResponseWriter , r * http.Request ) {
5284 entries , err := os .ReadDir ("." )
5385 if err != nil {
@@ -100,24 +132,48 @@ func (r *RootCmd) WebsocketServer() *serpent.Command {
100132
101133func websocketHandler (logger slog.Logger ) func (rw http.ResponseWriter , r * http.Request ) {
102134 return func (rw http.ResponseWriter , r * http.Request ) {
103- conn , err := websocket .Accept (rw , r , nil )
104- if err != nil {
105- http .Error (rw , "Could not accept websocket connection" , http .StatusInternalServerError )
106- return
107- }
108-
135+
136+ logger .Debug (r .Context (), "WebSocket connection attempt" ,
137+ slog .F ("remote_addr" , r .RemoteAddr ),
138+ slog .F ("path" , r .URL .Path ),
139+ slog .F ("query" , r .URL .RawQuery ))
140+
141+ // Validate all parameters BEFORE upgrading the connection
109142 dir := chi .URLParam (r , "dir" )
143+ logger .Debug (r .Context (), "Directory parameter" , slog .F ("dir" , dir ))
144+
110145 dinfo , err := os .Stat (dir )
111146 if err != nil {
112- _ = conn .Close (websocket .StatusInternalError , "Could not stat directory" )
147+ logger .Error (r .Context (), "Directory validation failed" ,
148+ slog .Error (err ),
149+ slog .F ("dir" , dir ))
150+ http .Error (rw , "Could not stat directory: " + err .Error (), http .StatusBadRequest )
113151 return
114152 }
115153
116154 if ! dinfo .IsDir () {
117- _ = conn . Close ( websocket . StatusInternalError , "Not a directory" )
155+ http . Error ( rw , "Not a directory" , http . StatusBadRequest )
118156 return
119157 }
120158
159+ // Log before WebSocket upgrade
160+ logger .Debug (r .Context (), "Attempting WebSocket upgrade" )
161+
162+ // Create WebSocket options with proper origin check
163+ options := & websocket.AcceptOptions {
164+ OriginPatterns : []string {
165+ "*" ,
166+ },
167+ }
168+
169+ conn , err := websocket .Accept (rw , r , options )
170+ if err != nil {
171+ logger .Error (r .Context (), "WebSocket upgrade failed" , slog .Error (err ))
172+ http .Error (rw , "Could not accept websocket connection: " + err .Error (), http .StatusInternalServerError )
173+ return
174+ }
175+ logger .Debug (r .Context (), "WebSocket connection established" )
176+
121177 dirFS := os .DirFS (dir )
122178 planPath := r .URL .Query ().Get ("plan" )
123179
0 commit comments