Skip to content

Commit ae74722

Browse files
committed
chore: error handling for websocket
1 parent b6e3651 commit ae74722

File tree

1 file changed

+64
-8
lines changed

1 file changed

+64
-8
lines changed

cli/web.go

Lines changed: 64 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
package cli
22

33
import (
4+
"bufio"
45
"context"
56
"embed"
67
"encoding/json"
8+
"fmt"
79
"io/fs"
810
"net"
911
"net/http"
@@ -23,6 +25,33 @@ import (
2325
//go:embed static/*
2426
var static embed.FS
2527

28+
type responseRecorder struct {
29+
http.ResponseWriter
30+
headerWritten bool
31+
logger slog.Logger
32+
}
33+
34+
// Implement Hijacker interface for WebSocket support
35+
func (r *responseRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) {
36+
if hijacker, ok := r.ResponseWriter.(http.Hijacker); ok {
37+
return hijacker.Hijack()
38+
}
39+
return nil, nil, fmt.Errorf("responseRecorder does not implement http.Hijacker")
40+
}
41+
42+
// Wrap your handler
43+
func debugMiddleware(logger slog.Logger) func(http.Handler) http.Handler {
44+
return func(next http.Handler) http.Handler {
45+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
46+
recorder := &responseRecorder{
47+
ResponseWriter: w,
48+
logger: logger,
49+
}
50+
next.ServeHTTP(recorder, r)
51+
})
52+
}
53+
}
54+
2655
func (r *RootCmd) WebsocketServer() *serpent.Command {
2756
var (
2857
address string
@@ -48,6 +77,9 @@ func (r *RootCmd) WebsocketServer() *serpent.Command {
4877
logger := slog.Make(sloghuman.Sink(i.Stderr)).Leveled(slog.LevelDebug)
4978

5079
mux := chi.NewMux()
80+
81+
mux.Use(debugMiddleware(logger))
82+
5183
mux.HandleFunc("/directories", func(rw http.ResponseWriter, r *http.Request) {
5284
entries, err := os.ReadDir(".")
5385
if err != nil {
@@ -100,24 +132,48 @@ func (r *RootCmd) WebsocketServer() *serpent.Command {
100132

101133
func websocketHandler(logger slog.Logger) func(rw http.ResponseWriter, r *http.Request) {
102134
return func(rw http.ResponseWriter, r *http.Request) {
103-
conn, err := websocket.Accept(rw, r, nil)
104-
if err != nil {
105-
http.Error(rw, "Could not accept websocket connection", http.StatusInternalServerError)
106-
return
107-
}
108-
135+
136+
logger.Debug(r.Context(), "WebSocket connection attempt",
137+
slog.F("remote_addr", r.RemoteAddr),
138+
slog.F("path", r.URL.Path),
139+
slog.F("query", r.URL.RawQuery))
140+
141+
// Validate all parameters BEFORE upgrading the connection
109142
dir := chi.URLParam(r, "dir")
143+
logger.Debug(r.Context(), "Directory parameter", slog.F("dir", dir))
144+
110145
dinfo, err := os.Stat(dir)
111146
if err != nil {
112-
_ = conn.Close(websocket.StatusInternalError, "Could not stat directory")
147+
logger.Error(r.Context(), "Directory validation failed",
148+
slog.Error(err),
149+
slog.F("dir", dir))
150+
http.Error(rw, "Could not stat directory: "+err.Error(), http.StatusBadRequest)
113151
return
114152
}
115153

116154
if !dinfo.IsDir() {
117-
_ = conn.Close(websocket.StatusInternalError, "Not a directory")
155+
http.Error(rw, "Not a directory", http.StatusBadRequest)
118156
return
119157
}
120158

159+
// Log before WebSocket upgrade
160+
logger.Debug(r.Context(), "Attempting WebSocket upgrade")
161+
162+
// Create WebSocket options with proper origin check
163+
options := &websocket.AcceptOptions{
164+
OriginPatterns: []string{
165+
"*",
166+
},
167+
}
168+
169+
conn, err := websocket.Accept(rw, r, options)
170+
if err != nil {
171+
logger.Error(r.Context(), "WebSocket upgrade failed", slog.Error(err))
172+
http.Error(rw, "Could not accept websocket connection: "+err.Error(), http.StatusInternalServerError)
173+
return
174+
}
175+
logger.Debug(r.Context(), "WebSocket connection established")
176+
121177
dirFS := os.DirFS(dir)
122178
planPath := r.URL.Query().Get("plan")
123179

0 commit comments

Comments
 (0)