Skip to content

Commit 86906cb

Browse files
authored
fix(csrf): update middleware usage to comply with gorilla/csrf changes (#4343)
1 parent 358e13b commit 86906cb

File tree

8 files changed

+156
-18
lines changed

8 files changed

+156
-18
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,3 +60,4 @@ build/mage_output_file.go
6060
!**/**/testdata/*.rego
6161

6262
release-notes.md
63+
config/dev.yml

config/flipt.schema.cue

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ import "list"
4242
state_lifetime: =~#duration | *"10m"
4343
csrf?: {
4444
key: string
45+
secure?: bool
4546
}
4647
}
4748

config/flipt.schema.json

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,8 @@
8484
"csrf": {
8585
"type": "object",
8686
"properties": {
87-
"key": { "type": "string" }
87+
"key": { "type": "string" },
88+
"secure": { "type": "boolean" }
8889
},
8990
"required": []
9091
}
@@ -1426,8 +1427,7 @@
14261427
"experimental": {
14271428
"type": "object",
14281429
"additionalProperties": false,
1429-
"properties": {
1430-
},
1430+
"properties": {},
14311431
"title": "Experimental"
14321432
}
14331433
}

go.work.sum

Lines changed: 119 additions & 0 deletions
Large diffs are not rendered by default.

internal/cmd/http.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,10 +173,14 @@ func NewHTTPServer(
173173
r = csrf.UnsafeSkipCheck(r)
174174
}
175175

176+
if !cfg.Authentication.Session.CSRF.Secure {
177+
r = csrf.PlaintextHTTPRequest(r)
178+
}
179+
176180
handler.ServeHTTP(w, r)
177181
})
178182
})
179-
r.Use(csrf.Protect([]byte(key), csrf.Path("/")))
183+
r.Use(csrf.Protect([]byte(key), csrf.Path("/"), csrf.Secure(cfg.Authentication.Session.CSRF.Secure)))
180184
}
181185

182186
r.Mount("/api/v1", api)

internal/config/authentication.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,9 @@ func (c *AuthenticationConfig) setDefaults(v *viper.Viper) error {
134134
"session": map[string]any{
135135
"token_lifetime": "24h",
136136
"state_lifetime": "10m",
137+
"csrf": map[string]any{
138+
"secure": true,
139+
},
137140
},
138141
"methods": methods,
139142
})
@@ -241,6 +244,8 @@ type AuthenticationSession struct {
241244
type AuthenticationSessionCSRF struct {
242245
// Key is the private key string used to authenticate csrf tokens.
243246
Key string `json:"-" mapstructure:"key"`
247+
// Secure signals to the CSRF middleware that the request is being served over TLS or plaintext HTTP
248+
Secure bool `json:"secure,omitempty" mapstructure:"secure" yaml:"secure,omitempty"`
244249
}
245250

246251
// AuthenticationMethods is a set of configuration for each authentication

internal/config/config.go

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -438,7 +438,8 @@ func stringToEnumHookFunc[T constraints.Integer](mappings map[string]T) mapstruc
438438
return func(
439439
f reflect.Type,
440440
t reflect.Type,
441-
data interface{}) (interface{}, error) {
441+
data interface{},
442+
) (interface{}, error) {
442443
if f.Kind() != reflect.String {
443444
return data, nil
444445
}
@@ -456,7 +457,8 @@ func experimentalFieldSkipHookFunc(types ...reflect.Type) mapstructure.DecodeHoo
456457
return func(
457458
f reflect.Type,
458459
t reflect.Type,
459-
data interface{}) (interface{}, error) {
460+
data interface{},
461+
) (interface{}, error) {
460462
if len(types) == 0 {
461463
return data, nil
462464
}
@@ -482,7 +484,8 @@ func stringToEnvsubstHookFunc() mapstructure.DecodeHookFunc {
482484
return func(
483485
f reflect.Type,
484486
t reflect.Type,
485-
data interface{}) (interface{}, error) {
487+
data interface{},
488+
) (interface{}, error) {
486489
if f.Kind() != reflect.String || f != reflect.TypeOf("") {
487490
return data, nil
488491
}
@@ -501,7 +504,8 @@ func stringToSliceHookFunc() mapstructure.DecodeHookFunc {
501504
return func(
502505
f reflect.Kind,
503506
t reflect.Kind,
504-
data interface{}) (interface{}, error) {
507+
data interface{},
508+
) (interface{}, error) {
505509
if f != reflect.String || t != reflect.Slice {
506510
return data, nil
507511
}
@@ -637,6 +641,9 @@ func Default() *Config {
637641
Session: AuthenticationSession{
638642
TokenLifetime: 24 * time.Hour,
639643
StateLifetime: 10 * time.Minute,
644+
CSRF: AuthenticationSessionCSRF{
645+
Secure: true,
646+
},
640647
},
641648
},
642649

internal/config/config_test.go

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -578,7 +578,8 @@ func TestLoad(t *testing.T) {
578578
},
579579
},
580580
},
581-
}}
581+
},
582+
}
582583
return cfg
583584
},
584585
},
@@ -695,7 +696,8 @@ func TestLoad(t *testing.T) {
695696
TokenLifetime: 24 * time.Hour,
696697
StateLifetime: 10 * time.Minute,
697698
CSRF: AuthenticationSessionCSRF{
698-
Key: "abcdefghijklmnopqrstuvwxyz1234567890", //gitleaks:allow
699+
Key: "abcdefghijklmnopqrstuvwxyz1234567890", //gitleaks:allow
700+
Secure: true,
699701
},
700702
},
701703
Methods: AuthenticationMethods{
@@ -869,7 +871,8 @@ func TestLoad(t *testing.T) {
869871
TokenLifetime: 24 * time.Hour,
870872
StateLifetime: 10 * time.Minute,
871873
CSRF: AuthenticationSessionCSRF{
872-
Key: "abcdefghijklmnopqrstuvwxyz1234567890", //gitleaks:allow
874+
Key: "abcdefghijklmnopqrstuvwxyz1234567890", //gitleaks:allow
875+
Secure: true,
873876
},
874877
},
875878
Methods: AuthenticationMethods{
@@ -1789,13 +1792,11 @@ func TestGetConfigFile(t *testing.T) {
17891792
}
17901793
}
17911794

1792-
var (
1793-
// add any struct tags to match their camelCase equivalents here.
1794-
camelCaseMatchers = map[string]string{
1795-
"requireTLS": "requireTLS",
1796-
"discoveryURL": "discoveryURL",
1797-
}
1798-
)
1795+
// add any struct tags to match their camelCase equivalents here.
1796+
var camelCaseMatchers = map[string]string{
1797+
"requireTLS": "requireTLS",
1798+
"discoveryURL": "discoveryURL",
1799+
}
17991800

18001801
func TestStructTags(t *testing.T) {
18011802
configType := reflect.TypeOf(Config{})

0 commit comments

Comments
 (0)