Skip to content

Commit dd33ecd

Browse files
committed
[server] Add CSRF protection with nosurf middleware
This commit integrates the nosurf middleware for CSRF protection, updates the DotFlush and DotReq structs to include a Token method for accessing CSRF tokens, and adjusts the request context handling to support the middleware.
1 parent d3d170b commit dd33ecd

File tree

5 files changed

+28
-11
lines changed

5 files changed

+28
-11
lines changed

dot_flush.go

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ import (
77
"net/http"
88
"strings"
99
"time"
10+
11+
"github.com/justinas/nosurf"
1012
)
1113

1214
type dotFlushProvider struct{}
@@ -18,7 +20,7 @@ func (dotFlushProvider) Value(r Request) (any, error) {
1820
if !ok {
1921
return &DotFlush{}, fmt.Errorf("response writer could not cast to http.Flusher")
2022
}
21-
return &DotFlush{flusher: f, serverCtx: r.ServerCtx, requestCtx: r.R.Context()}, nil
23+
return &DotFlush{flusher: f, serverCtx: r.ServerCtx, request: r.R}, nil
2224
}
2325

2426
func (dotFlushProvider) Cleanup(v any, err error) error {
@@ -37,8 +39,9 @@ type flusher interface {
3739

3840
// DotFlush is used as the .Flush field for flushing template handlers (SSE).
3941
type DotFlush struct {
40-
flusher flusher
41-
serverCtx, requestCtx context.Context
42+
flusher flusher
43+
serverCtx context.Context
44+
request *http.Request
4245
}
4346

4447
// SendSSE sends an sse message by formatting the provided args as an sse event:
@@ -105,7 +108,7 @@ func (f *DotFlush) Repeat(max_ ...int) <-chan int {
105108
loop:
106109
for {
107110
select {
108-
case <-f.requestCtx.Done():
111+
case <-f.request.Context().Done():
109112
break loop
110113
case <-f.serverCtx.Done():
111114
break loop
@@ -125,7 +128,7 @@ func (f *DotFlush) Repeat(max_ ...int) <-chan int {
125128
func (f *DotFlush) Sleep(ms int) (string, error) {
126129
select {
127130
case <-time.After(time.Duration(ms) * time.Millisecond):
128-
case <-f.requestCtx.Done():
131+
case <-f.request.Context().Done():
129132
return "", ReturnError{}
130133
case <-f.serverCtx.Done():
131134
return "", ReturnError{}
@@ -137,9 +140,13 @@ func (f *DotFlush) Sleep(ms int) (string, error) {
137140
// client or until the server closes.
138141
func (f *DotFlush) WaitForServerStop() (string, error) {
139142
select {
140-
case <-f.requestCtx.Done():
143+
case <-f.request.Context().Done():
141144
return "", ReturnError{}
142145
case <-f.serverCtx.Done():
143146
return "", nil
144147
}
145148
}
149+
150+
func (f *DotFlush) Token() string {
151+
return nosurf.Token(f.request)
152+
}

dot_req.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ package xtemplate
33
import (
44
"context"
55
"net/http"
6+
7+
"github.com/justinas/nosurf"
68
)
79

810
type dotReqProvider struct{}
@@ -29,3 +31,7 @@ var _ DotConfig = dotReqProvider{}
2931
type DotReq struct {
3032
*http.Request
3133
}
34+
35+
func (d DotReq) Token() string {
36+
return nosurf.Token(d.Request)
37+
}

go.mod

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ require (
1515
github.com/felixge/httpsnoop v1.0.4
1616
github.com/google/uuid v1.6.0
1717
github.com/infogulch/watch v0.2.0
18+
github.com/justinas/nosurf v1.2.0
1819
github.com/klauspost/compress v1.18.0
1920
github.com/microcosm-cc/bluemonday v1.0.27
2021
github.com/nats-io/nats-server/v2 v2.10.24

go.sum

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,8 @@ github.com/jellevandenhooff/dkim v0.0.0-20150330215556-f50fe3d243e1/go.mod h1:E0
308308
github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI=
309309
github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU=
310310
github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU=
311+
github.com/justinas/nosurf v1.2.0 h1:yMs1bSRrNiwXk4AS6n8vL2Ssgpb9CB25T/4xrixaK0s=
312+
github.com/justinas/nosurf v1.2.0/go.mod h1:ALpWdSbuNGy2lZWtyXdjkYv4edL23oSEgfBT1gPJ5BQ=
311313
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
312314
github.com/klauspost/compress v1.12.3/go.mod h1:8dP1Hq4DHOhN9w426knH3Rhby4rFm6D8eO+e+Dq5Gzg=
313315
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=

instance.go

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import (
2020

2121
"github.com/Masterminds/sprig/v3"
2222
"github.com/felixge/httpsnoop"
23+
"github.com/justinas/nosurf"
2324
"github.com/spf13/afero"
2425

2526
"github.com/google/uuid"
@@ -251,17 +252,17 @@ func (instance *Instance) ServeHTTP(w http.ResponseWriter, r *http.Request) {
251252
log := instance.config.Logger.With(slog.Group("serve",
252253
slog.String("requestid", rid),
253254
))
254-
log.LogAttrs(r.Context(), slog.LevelDebug, "serving request",
255+
log.LogAttrs(ctx, slog.LevelDebug, "serving request",
255256
slog.String("user-agent", r.Header.Get("User-Agent")),
256257
slog.String("method", r.Method),
257258
slog.String("requestPath", r.URL.Path),
258259
)
259-
ctx = context.WithValue(ctx, loggerKey, log)
260260

261-
r = r.WithContext(ctx)
262-
metrics := httpsnoop.CaptureMetrics(instance.router, w, r)
261+
metrics := httpsnoop.CaptureMetricsFn(w, func(ww http.ResponseWriter) {
262+
nosurf.New(instance.router).ServeHTTP(ww, r.WithContext(context.WithValue(ctx, loggerKey, log)))
263+
})
263264

264-
log.LogAttrs(r.Context(), levelDebug2, "request served",
265+
log.LogAttrs(ctx, levelDebug2, "request served",
265266
slog.Group("response",
266267
slog.Duration("duration", metrics.Duration),
267268
slog.Int("statusCode", metrics.Code),

0 commit comments

Comments
 (0)