Skip to content

Commit e574c0d

Browse files
authored
fix(middleware/cors): CORS handling (#2937)
* fix(middleware/cors): CORS handling * fix(middleware/cors): Vary header handling * test(middleware/cors): Ensure Vary Headers checked
1 parent 43d5091 commit e574c0d

File tree

2 files changed

+70
-46
lines changed

2 files changed

+70
-46
lines changed

middleware/cors/cors.go

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -162,9 +162,19 @@ func New(config ...Config) fiber.Handler {
162162
// Get originHeader header
163163
originHeader := strings.ToLower(c.Get(fiber.HeaderOrigin))
164164

165-
// If the request does not have Origin and Access-Control-Request-Method
166-
// headers, the request is outside the scope of CORS
167-
if originHeader == "" || c.Get(fiber.HeaderAccessControlRequestMethod) == "" {
165+
// If the request does not have Origin header, the request is outside the scope of CORS
166+
if originHeader == "" {
167+
// See https://fetch.spec.whatwg.org/#cors-protocol-and-http-caches
168+
// Unless all origins are allowed, we include the Vary header to cache the response correctly
169+
if !allowAllOrigins {
170+
c.Vary(fiber.HeaderOrigin)
171+
}
172+
173+
return c.Next()
174+
}
175+
176+
// If it's a preflight request and doesn't have Access-Control-Request-Method header, it's outside the scope of CORS
177+
if c.Method() == fiber.MethodOptions && c.Get(fiber.HeaderAccessControlRequestMethod) == "" {
168178
return c.Next()
169179
}
170180

@@ -204,13 +214,23 @@ func New(config ...Config) fiber.Handler {
204214
// Simple request
205215
// Ommit allowMethods and allowHeaders, only used for pre-flight requests
206216
if c.Method() != fiber.MethodOptions {
217+
if !allowAllOrigins {
218+
// See https://fetch.spec.whatwg.org/#cors-protocol-and-http-caches
219+
c.Vary(fiber.HeaderOrigin)
220+
}
207221
setCORSHeaders(c, allowOrigin, "", "", exposeHeaders, maxAge, cfg)
208222
return c.Next()
209223
}
210224

211-
// Preflight request
225+
// Pre-flight request
226+
227+
// Response to OPTIONS request should not be cached but,
228+
// some caching can be configured to cache such responses.
229+
// To Avoid poisoning the cache, we include the Vary header
230+
// of preflight responses:
212231
c.Vary(fiber.HeaderAccessControlRequestMethod)
213232
c.Vary(fiber.HeaderAccessControlRequestHeaders)
233+
c.Vary(fiber.HeaderOrigin)
214234

215235
setCORSHeaders(c, allowOrigin, allowMethods, allowHeaders, exposeHeaders, maxAge, cfg)
216236

@@ -221,8 +241,6 @@ func New(config ...Config) fiber.Handler {
221241

222242
// Function to set CORS headers
223243
func setCORSHeaders(c *fiber.Ctx, allowOrigin, allowMethods, allowHeaders, exposeHeaders, maxAge string, cfg Config) {
224-
c.Vary(fiber.HeaderOrigin)
225-
226244
if cfg.AllowCredentials {
227245
// When AllowCredentials is true, set the Access-Control-Allow-Origin to the specific origin instead of '*'
228246
if allowOrigin == "*" {

middleware/cors/cors_test.go

Lines changed: 46 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,6 @@ func testDefaultOrEmptyConfig(t *testing.T, app *fiber.App) {
5050
// Test default GET response headers
5151
ctx := &fasthttp.RequestCtx{}
5252
ctx.Request.Header.SetMethod(fiber.MethodGet)
53-
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
5453
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost")
5554
h(ctx)
5655

@@ -70,6 +69,44 @@ func testDefaultOrEmptyConfig(t *testing.T, app *fiber.App) {
7069
utils.AssertEqual(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlMaxAge)))
7170
}
7271

72+
func Test_CORS_AllowOrigins_Vary(t *testing.T) {
73+
t.Parallel()
74+
app := fiber.New()
75+
app.Use(New(
76+
Config{
77+
AllowOrigins: "http://localhost",
78+
},
79+
))
80+
81+
h := app.Handler()
82+
83+
// Test Vary header non-Cors request
84+
ctx := &fasthttp.RequestCtx{}
85+
ctx.Request.Header.SetMethod(fiber.MethodGet)
86+
h(ctx)
87+
utils.AssertEqual(t, true, strings.Contains(string(ctx.Response.Header.Peek(fiber.HeaderVary)), fiber.HeaderOrigin), "Vary header should be set for Origin")
88+
89+
// Test Vary header Cors preflight request
90+
ctx.Request.Reset()
91+
ctx.Response.Reset()
92+
ctx.Request.Header.SetMethod(fiber.MethodOptions)
93+
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
94+
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost")
95+
h(ctx)
96+
vh := string(ctx.Response.Header.Peek(fiber.HeaderVary))
97+
utils.AssertEqual(t, true, strings.Contains(vh, fiber.HeaderOrigin), "Vary header should be set for Origin")
98+
utils.AssertEqual(t, true, strings.Contains(vh, fiber.HeaderAccessControlRequestMethod), "Vary header should be set for Access-Control-Request-Method")
99+
utils.AssertEqual(t, true, strings.Contains(vh, fiber.HeaderAccessControlRequestHeaders), "Vary header should be set for Access-Control-Request-Headers")
100+
101+
// Test Vary header Cors request
102+
ctx.Request.Reset()
103+
ctx.Response.Reset()
104+
ctx.Request.Header.SetMethod(fiber.MethodGet)
105+
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost")
106+
h(ctx)
107+
utils.AssertEqual(t, true, strings.Contains(string(ctx.Response.Header.Peek(fiber.HeaderVary)), fiber.HeaderOrigin), "Vary header should be set for Origin")
108+
}
109+
73110
// go test -run -v Test_CORS_Wildcard
74111
func Test_CORS_Wildcard(t *testing.T) {
75112
t.Parallel()
@@ -97,6 +134,10 @@ func Test_CORS_Wildcard(t *testing.T) {
97134

98135
// Check result
99136
utils.AssertEqual(t, "*", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin))) // Validates request is not reflecting origin in the response
137+
vh := string(ctx.Response.Header.Peek(fiber.HeaderVary))
138+
utils.AssertEqual(t, true, strings.Contains(vh, fiber.HeaderOrigin), "Vary header should be set for Origin")
139+
utils.AssertEqual(t, true, strings.Contains(vh, fiber.HeaderAccessControlRequestMethod), "Vary header should be set for Access-Control-Request-Method")
140+
utils.AssertEqual(t, true, strings.Contains(vh, fiber.HeaderAccessControlRequestHeaders), "Vary header should be set for Access-Control-Request-Headers")
100141
utils.AssertEqual(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowCredentials)))
101142
utils.AssertEqual(t, "3600", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlMaxAge)))
102143
utils.AssertEqual(t, "Authentication", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowHeaders)))
@@ -105,9 +146,9 @@ func Test_CORS_Wildcard(t *testing.T) {
105146
ctx = &fasthttp.RequestCtx{}
106147
ctx.Request.Header.SetMethod(fiber.MethodGet)
107148
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost")
108-
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
109149
handler(ctx)
110150

151+
utils.AssertEqual(t, false, strings.Contains(string(ctx.Response.Header.Peek(fiber.HeaderVary)), fiber.HeaderOrigin), "Vary header should not be set for Origin")
111152
utils.AssertEqual(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowCredentials)))
112153
utils.AssertEqual(t, "X-Request-ID", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlExposeHeaders)))
113154
}
@@ -147,7 +188,6 @@ func Test_CORS_Origin_AllowCredentials(t *testing.T) {
147188
// Test non OPTIONS (preflight) response headers
148189
ctx = &fasthttp.RequestCtx{}
149190
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost")
150-
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
151191
ctx.Request.Header.SetMethod(fiber.MethodGet)
152192
handler(ctx)
153193

@@ -466,7 +506,7 @@ func Test_CORS_Headers_BasedOnRequestType(t *testing.T) {
466506
// Get handler pointer
467507
handler := app.Handler()
468508

469-
t.Run("Without origin and Access-Control-Request-Method headers", func(t *testing.T) {
509+
t.Run("Without origin", func(t *testing.T) {
470510
t.Parallel()
471511
// Make request without origin header, and without Access-Control-Request-Method
472512
for _, method := range methods {
@@ -479,34 +519,6 @@ func Test_CORS_Headers_BasedOnRequestType(t *testing.T) {
479519
}
480520
})
481521

482-
t.Run("With origin but without Access-Control-Request-Method header", func(t *testing.T) {
483-
t.Parallel()
484-
// Make request with origin header, but without Access-Control-Request-Method
485-
for _, method := range methods {
486-
ctx := &fasthttp.RequestCtx{}
487-
ctx.Request.Header.SetMethod(method)
488-
ctx.Request.SetRequestURI("https://example.com/")
489-
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example.com")
490-
handler(ctx)
491-
utils.AssertEqual(t, 200, ctx.Response.StatusCode(), "Status code should be 200")
492-
utils.AssertEqual(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)), "Access-Control-Allow-Origin header should not be set")
493-
}
494-
})
495-
496-
t.Run("Without origin but with Access-Control-Request-Method header", func(t *testing.T) {
497-
t.Parallel()
498-
// Make request without origin header, but with Access-Control-Request-Method
499-
for _, method := range methods {
500-
ctx := &fasthttp.RequestCtx{}
501-
ctx.Request.Header.SetMethod(method)
502-
ctx.Request.SetRequestURI("https://example.com/")
503-
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
504-
handler(ctx)
505-
utils.AssertEqual(t, 200, ctx.Response.StatusCode(), "Status code should be 200")
506-
utils.AssertEqual(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)), "Access-Control-Allow-Origin header should not be set")
507-
}
508-
})
509-
510522
t.Run("Preflight request with origin and Access-Control-Request-Method headers", func(t *testing.T) {
511523
t.Parallel()
512524
// Make preflight request with origin header and with Access-Control-Request-Method
@@ -524,15 +536,14 @@ func Test_CORS_Headers_BasedOnRequestType(t *testing.T) {
524536
}
525537
})
526538

527-
t.Run("Non-preflight request with origin and Access-Control-Request-Method headers", func(t *testing.T) {
539+
t.Run("Non-preflight request with origin", func(t *testing.T) {
528540
t.Parallel()
529541
// Make non-preflight request with origin header and with Access-Control-Request-Method
530542
for _, method := range methods {
531543
ctx := &fasthttp.RequestCtx{}
532544
ctx.Request.Header.SetMethod(method)
533545
ctx.Request.SetRequestURI("https://example.com/api/action")
534546
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example.com")
535-
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, method)
536547
handler(ctx)
537548
utils.AssertEqual(t, 200, ctx.Response.StatusCode(), "Status code should be 200")
538549
utils.AssertEqual(t, "*", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)), "Access-Control-Allow-Origin header should be set")
@@ -901,7 +912,6 @@ func Benchmark_CORS_NewHandler(b *testing.B) {
901912
req.Header.SetMethod(fiber.MethodGet)
902913
req.SetRequestURI("/")
903914
req.Header.Set(fiber.HeaderOrigin, "http://localhost")
904-
req.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
905915
req.Header.Set(fiber.HeaderAccessControlRequestHeaders, "Origin,Content-Type,Accept")
906916

907917
ctx.Init(req, nil, nil)
@@ -942,7 +952,6 @@ func Benchmark_CORS_NewHandlerParallel(b *testing.B) {
942952
req.Header.SetMethod(fiber.MethodGet)
943953
req.SetRequestURI("/")
944954
req.Header.Set(fiber.HeaderOrigin, "http://localhost")
945-
req.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
946955
req.Header.Set(fiber.HeaderAccessControlRequestHeaders, "Origin,Content-Type,Accept")
947956

948957
ctx.Init(req, nil, nil)
@@ -976,7 +985,6 @@ func Benchmark_CORS_NewHandlerSingleOrigin(b *testing.B) {
976985
req.Header.SetMethod(fiber.MethodGet)
977986
req.SetRequestURI("/")
978987
req.Header.Set(fiber.HeaderOrigin, "http://example.com")
979-
req.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
980988
req.Header.Set(fiber.HeaderAccessControlRequestHeaders, "Origin,Content-Type,Accept")
981989

982990
ctx.Init(req, nil, nil)
@@ -1017,7 +1025,6 @@ func Benchmark_CORS_NewHandlerSingleOriginParallel(b *testing.B) {
10171025
req.Header.SetMethod(fiber.MethodGet)
10181026
req.SetRequestURI("/")
10191027
req.Header.Set(fiber.HeaderOrigin, "http://example.com")
1020-
req.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
10211028
req.Header.Set(fiber.HeaderAccessControlRequestHeaders, "Origin,Content-Type,Accept")
10221029

10231030
ctx.Init(req, nil, nil)
@@ -1051,7 +1058,6 @@ func Benchmark_CORS_NewHandlerWildcard(b *testing.B) {
10511058
req.Header.SetMethod(fiber.MethodGet)
10521059
req.SetRequestURI("/")
10531060
req.Header.Set(fiber.HeaderOrigin, "http://example.com")
1054-
req.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
10551061
req.Header.Set(fiber.HeaderAccessControlRequestHeaders, "Origin,Content-Type,Accept")
10561062

10571063
ctx.Init(req, nil, nil)
@@ -1092,7 +1098,6 @@ func Benchmark_CORS_NewHandlerWildcardParallel(b *testing.B) {
10921098
req.Header.SetMethod(fiber.MethodGet)
10931099
req.SetRequestURI("/")
10941100
req.Header.Set(fiber.HeaderOrigin, "http://example.com")
1095-
req.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
10961101
req.Header.Set(fiber.HeaderAccessControlRequestHeaders, "Origin,Content-Type,Accept")
10971102

10981103
ctx.Init(req, nil, nil)
@@ -1122,6 +1127,7 @@ func Benchmark_CORS_NewHandlerPreflight(b *testing.B) {
11221127
h := app.Handler()
11231128
ctx := &fasthttp.RequestCtx{}
11241129

1130+
// Preflight request
11251131
req := &fasthttp.Request{}
11261132
req.Header.SetMethod(fiber.MethodOptions)
11271133
req.SetRequestURI("/")

0 commit comments

Comments
 (0)