Skip to content

Commit 26ab188

Browse files
committed
CORS: add an optional custom function to validate the origin
1 parent 17a5fca commit 26ab188

File tree

2 files changed

+91
-27
lines changed

2 files changed

+91
-27
lines changed

middleware/cors.go

Lines changed: 44 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,13 @@ type (
1919
// Optional. Default value []string{"*"}.
2020
AllowOrigins []string `yaml:"allow_origins"`
2121

22+
// AllowOriginFunc is a custom function to validate the origin. It takes the
23+
// origin as an argument and returns true if allowed or false otherwise. If
24+
// an error is returned, it is returned by the handler. If this option is
25+
// set, AllowOrigins is ignored.
26+
// Optional.
27+
AllowOriginFunc func(origin string) (bool, error) `yaml:"allow_origin_func"`
28+
2229
// AllowMethods defines a list methods allowed when accessing the resource.
2330
// This is used in response to a preflight request.
2431
// Optional. Default value DefaultCORSConfig.AllowMethods.
@@ -113,40 +120,50 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
113120
return c.NoContent(http.StatusNoContent)
114121
}
115122

116-
// Check allowed origins
117-
for _, o := range config.AllowOrigins {
118-
if o == "*" && config.AllowCredentials {
119-
allowOrigin = origin
120-
break
121-
}
122-
if o == "*" || o == origin {
123-
allowOrigin = o
124-
break
125-
}
126-
if matchSubdomain(origin, o) {
127-
allowOrigin = origin
128-
break
129-
}
130-
}
131-
132-
// Check allowed origin patterns
133-
for _, re := range allowOriginPatterns {
134-
if allowOrigin == "" {
135-
didx := strings.Index(origin, "://")
136-
if didx == -1 {
137-
continue
123+
if config.AllowOriginFunc == nil {
124+
// Check allowed origins
125+
for _, o := range config.AllowOrigins {
126+
if o == "*" && config.AllowCredentials {
127+
allowOrigin = origin
128+
break
138129
}
139-
domAuth := origin[didx+3:]
140-
// to avoid regex cost by invalid long domain
141-
if len(domAuth) > 253 {
130+
if o == "*" || o == origin {
131+
allowOrigin = o
142132
break
143133
}
144-
145-
if match, _ := regexp.MatchString(re, origin); match {
134+
if matchSubdomain(origin, o) {
146135
allowOrigin = origin
147136
break
148137
}
149138
}
139+
140+
// Check allowed origin patterns
141+
for _, re := range allowOriginPatterns {
142+
if allowOrigin == "" {
143+
didx := strings.Index(origin, "://")
144+
if didx == -1 {
145+
continue
146+
}
147+
domAuth := origin[didx+3:]
148+
// to avoid regex cost by invalid long domain
149+
if len(domAuth) > 253 {
150+
break
151+
}
152+
153+
if match, _ := regexp.MatchString(re, origin); match {
154+
allowOrigin = origin
155+
break
156+
}
157+
}
158+
}
159+
} else {
160+
allowed, err := config.AllowOriginFunc(origin)
161+
if err != nil {
162+
return err
163+
}
164+
if allowed {
165+
allowOrigin = origin
166+
}
150167
}
151168

152169
// Origin not allowed

middleware/cors_test.go

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package middleware
22

33
import (
4+
"errors"
45
"net/http"
56
"net/http/httptest"
67
"testing"
@@ -360,3 +361,49 @@ func TestCorsHeaders(t *testing.T) {
360361
}
361362
}
362363
}
364+
365+
func Test_allowOriginFunc(t *testing.T) {
366+
returnTrue := func(origin string) (bool, error) {
367+
return true, nil
368+
}
369+
returnFalse := func(origin string) (bool, error) {
370+
return false, nil
371+
}
372+
returnError := func(origin string) (bool, error) {
373+
return true, errors.New("this is a test error")
374+
}
375+
376+
allowOriginFuncs := []func(origin string) (bool, error){
377+
returnTrue,
378+
returnFalse,
379+
returnError,
380+
}
381+
382+
const origin = "http://example.com"
383+
384+
e := echo.New()
385+
for _, allowOriginFunc := range allowOriginFuncs {
386+
req := httptest.NewRequest(http.MethodOptions, "/", nil)
387+
rec := httptest.NewRecorder()
388+
c := e.NewContext(req, rec)
389+
req.Header.Set(echo.HeaderOrigin, origin)
390+
cors := CORSWithConfig(CORSConfig{
391+
AllowOriginFunc: allowOriginFunc,
392+
})
393+
h := cors(echo.NotFoundHandler)
394+
err := h(c)
395+
396+
expected, expectedErr := allowOriginFunc(origin)
397+
if expectedErr != nil {
398+
assert.Equal(t, expectedErr, err)
399+
assert.Equal(t, "", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
400+
continue
401+
}
402+
403+
if expected {
404+
assert.Equal(t, origin, rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
405+
} else {
406+
assert.Equal(t, "", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
407+
}
408+
}
409+
}

0 commit comments

Comments
 (0)