Skip to content

Commit 1f31354

Browse files
authored
feat: add CORS middleware (#46)
1 parent b3f6f7c commit 1f31354

File tree

2 files changed

+338
-0
lines changed

2 files changed

+338
-0
lines changed

middlewares/cors.go

Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
package middlewares
2+
3+
import (
4+
"net/http"
5+
"strconv"
6+
"strings"
7+
"time"
8+
9+
"github.com/mojixcoder/kid"
10+
)
11+
12+
// CorsConfig is the config used to build CORS middleware.
13+
type CorsConfig struct {
14+
// AllowedOrigins specifies which origins can access the resource.
15+
// If "*" is in the list, all origins will be allowed.
16+
//
17+
// Defaults to ["*"]
18+
AllowedOrigins []string
19+
20+
// AllowOriginFunc is a custom function for validating the origin.
21+
// The origin will always be set and you don't need to check that in this function.
22+
//
23+
// If you set this function the rest of validation logic will be ignored.
24+
//
25+
// Defaults to nil.
26+
AllowOriginFunc func(c *kid.Context, origin string) bool
27+
28+
// AllowedMethods is the list of allowed HTTP methods.
29+
//
30+
// Defaults to ["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"].
31+
AllowedMethods []string
32+
33+
// AllowedHeaders is the list of the custom headers which are allowed to be sent.
34+
//
35+
// If "*" is in the list, all headers will be allowed.
36+
AllowedHeaders []string
37+
38+
// ExposedHeaders a list of headers that clients are allowed to access.
39+
//
40+
// Defaults to [].
41+
ExposedHeaders []string
42+
43+
// MaxAge is the maximum duration that the response to the preflight request can be cached before another call is made.
44+
// In second percision.
45+
//
46+
// Will not be used if 0.
47+
// Defaults to 0.
48+
MaxAge time.Duration
49+
50+
// AllowCredentials if true, cookies will be allowed to be included in cross-site HTTP requests.
51+
//
52+
// defaults to false.
53+
AllowCredentials bool
54+
55+
// AllowPrivateNetwork if true, allow requests from sites on “public” IP to this server on a “private” IP.
56+
//
57+
// defaults to false.
58+
AllowPrivateNetwork bool
59+
60+
// allowAllOrigins will be true if "*" is in allowed origins.
61+
allowAllOrigins bool
62+
}
63+
64+
// DefaultCorsConfig is the default CORS config.
65+
var DefaultCorsConfig = CorsConfig{
66+
AllowedOrigins: []string{"*"},
67+
AllowedMethods: []string{
68+
http.MethodGet, http.MethodPost, http.MethodPut,
69+
http.MethodPatch, http.MethodDelete, http.MethodOptions,
70+
},
71+
}
72+
73+
// NewCors returns a new CORS config.
74+
func NewCors() kid.MiddlewareFunc {
75+
return NewCorsWithConfig(DefaultCorsConfig)
76+
}
77+
78+
// NewCorsWithConfig returns a new CORS middleware with the given config.
79+
func NewCorsWithConfig(cfg CorsConfig) kid.MiddlewareFunc {
80+
setCorsDefaults(&cfg)
81+
82+
allowedMethods := strings.Join(cfg.AllowedMethods, ", ")
83+
allowedHeaders := strings.Join(cfg.AllowedHeaders, ", ")
84+
exposedHeaders := strings.Join(cfg.ExposedHeaders, ", ")
85+
maxAge := strconv.Itoa(int(cfg.MaxAge.Seconds()))
86+
allowCreds := "false"
87+
if cfg.AllowCredentials {
88+
allowCreds = "true"
89+
}
90+
91+
return func(next kid.HandlerFunc) kid.HandlerFunc {
92+
return func(c *kid.Context) {
93+
req := c.Request()
94+
header := c.Response().Header()
95+
preflight := isPreflight(req)
96+
97+
header.Set("Vary", "Origin")
98+
99+
origin := req.Header.Get("Origin")
100+
if origin == "" {
101+
next(c)
102+
return
103+
}
104+
105+
if !cfg.isAllowedOrigin(c, origin) {
106+
next(c)
107+
return
108+
}
109+
110+
if cfg.allowAllOrigins && !cfg.AllowCredentials {
111+
header.Set("Access-Control-Allow-Origin", "*")
112+
} else {
113+
header.Set("Access-Control-Allow-Origin", origin)
114+
}
115+
116+
if cfg.AllowPrivateNetwork && req.Header.Get("Access-Control-Request-Private-Network") == "true" {
117+
header.Set("Access-Control-Allow-Private-Network", "true")
118+
}
119+
120+
setHeader(header, "Access-Control-Allow-Credentials", allowCreds, "false")
121+
setHeader(header, "Access-Control-Expose-Headers", exposedHeaders, "")
122+
123+
switch preflight {
124+
case false:
125+
next(c)
126+
case true:
127+
setHeader(header, "Access-Control-Allow-Methods", allowedMethods, "")
128+
setHeader(header, "Access-Control-Allow-Headers", allowedHeaders, "")
129+
setHeader(header, "Access-Control-Max-Age", maxAge, "0")
130+
131+
c.NoContent(http.StatusNoContent)
132+
}
133+
}
134+
}
135+
}
136+
137+
// isPreflight checks if this is a preflight request.
138+
func isPreflight(req *http.Request) bool {
139+
return req.Method == http.MethodOptions && req.Header.Get("Access-Control-Request-Method") != ""
140+
}
141+
142+
// isAllowedOrigin validates the origin.
143+
func (cors *CorsConfig) isAllowedOrigin(c *kid.Context, origin string) bool {
144+
if cors.AllowOriginFunc != nil {
145+
return cors.AllowOriginFunc(c, origin)
146+
}
147+
148+
if cors.allowAllOrigins {
149+
return true
150+
}
151+
152+
for _, v := range cors.AllowedOrigins {
153+
if v == "*" {
154+
cors.allowAllOrigins = true
155+
return true
156+
}
157+
if v == origin {
158+
return true
159+
}
160+
}
161+
162+
return false
163+
}
164+
165+
// setHeader sets the header if not empty.
166+
func setHeader(header http.Header, key, value, emptyValue string) {
167+
if value != emptyValue {
168+
header.Set(key, value)
169+
}
170+
}
171+
172+
// setCorsDefaults sets the default CORS configs.
173+
func setCorsDefaults(cfg *CorsConfig) {
174+
if len(cfg.AllowedOrigins) == 0 {
175+
cfg.AllowedOrigins = DefaultCorsConfig.AllowedOrigins
176+
}
177+
178+
if len(cfg.AllowedMethods) == 0 {
179+
cfg.AllowedMethods = DefaultCorsConfig.AllowedMethods
180+
}
181+
}

middlewares/cors_test.go

Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
package middlewares
2+
3+
import (
4+
"net/http"
5+
"net/http/httptest"
6+
"strconv"
7+
"testing"
8+
"time"
9+
10+
"github.com/mojixcoder/kid"
11+
"github.com/stretchr/testify/assert"
12+
)
13+
14+
func TestSetHeader(t *testing.T) {
15+
testCases := []struct {
16+
key string
17+
val string
18+
emptyVal string
19+
expectedRes string
20+
}{
21+
{key: "empty_max_age", val: strconv.Itoa(int(time.Duration(0).Seconds())), emptyVal: "0", expectedRes: ""},
22+
{key: "max_age", val: strconv.Itoa(int((time.Hour).Seconds())), emptyVal: "0", expectedRes: "3600"},
23+
{key: "empty_headers", val: "", emptyVal: "", expectedRes: ""},
24+
{key: "headers", val: "headers", emptyVal: "", expectedRes: "headers"},
25+
{key: "empty_creds", val: "false", emptyVal: "false", expectedRes: ""},
26+
{key: "creds", val: "true", emptyVal: "false", expectedRes: "true"},
27+
}
28+
29+
header := make(http.Header)
30+
31+
for _, testCase := range testCases {
32+
t.Run(testCase.key, func(t *testing.T) {
33+
setHeader(header, testCase.key, testCase.val, testCase.emptyVal)
34+
assert.Equal(t, testCase.expectedRes, header.Get(testCase.key))
35+
})
36+
}
37+
}
38+
39+
func TestSetCorsDefaults(t *testing.T) {
40+
cors := &CorsConfig{}
41+
42+
setCorsDefaults(cors)
43+
assert.Equal(t, DefaultCorsConfig.AllowedOrigins, cors.AllowedOrigins)
44+
assert.Equal(t, DefaultCorsConfig.AllowedMethods, cors.AllowedMethods)
45+
46+
cors = &CorsConfig{AllowedMethods: []string{http.MethodConnect}}
47+
48+
setCorsDefaults(cors)
49+
assert.Equal(t, []string{http.MethodConnect}, cors.AllowedMethods)
50+
}
51+
52+
func TestIsPreflight(t *testing.T) {
53+
req := httptest.NewRequest(http.MethodOptions, "/test", nil)
54+
assert.False(t, isPreflight(req))
55+
56+
req.Header.Set("Access-Control-Request-Method", http.MethodGet)
57+
assert.True(t, isPreflight(req))
58+
}
59+
60+
func TestCorsConfig_isAllowedOrigin(t *testing.T) {
61+
cfg := CorsConfig{AllowedOrigins: []string{"http://localhost:2376"}}
62+
63+
assert.True(t, cfg.isAllowedOrigin(nil, "http://localhost:2376"))
64+
assert.False(t, cfg.isAllowedOrigin(nil, "http://localhost:2377"))
65+
assert.False(t, cfg.allowAllOrigins)
66+
67+
cfg.AllowedOrigins = []string{"*"}
68+
69+
assert.True(t, cfg.isAllowedOrigin(nil, "http://localhost:2376"))
70+
assert.True(t, cfg.allowAllOrigins)
71+
assert.True(t, cfg.isAllowedOrigin(nil, "http://localhost:2377"))
72+
73+
cfg.AllowOriginFunc = func(c *kid.Context, origin string) bool {
74+
return false
75+
}
76+
77+
assert.False(t, cfg.isAllowedOrigin(nil, "http://localhost:2376"))
78+
}
79+
80+
func TestNewCors(t *testing.T) {
81+
k := kid.New()
82+
k.Use(NewCors())
83+
84+
res := httptest.NewRecorder()
85+
req := httptest.NewRequest(http.MethodOptions, "/test", nil)
86+
req.Header.Add("Access-Control-Request-Method", http.MethodPost)
87+
req.Header.Add("Origin", "http://localhost:2376")
88+
89+
k.ServeHTTP(res, req)
90+
91+
assert.Equal(t, http.StatusNoContent, res.Code)
92+
assert.Equal(t, "Origin", res.Header().Get("Vary"))
93+
assert.Equal(t, "*", res.Header().Get("Access-Control-Allow-Origin"))
94+
assert.Equal(t, "GET, POST, PUT, PATCH, DELETE, OPTIONS", res.Header().Get("Access-Control-Allow-Methods"))
95+
96+
res = httptest.NewRecorder()
97+
req = httptest.NewRequest(http.MethodPost, "/test", nil)
98+
req.Header.Add("Origin", "http://localhost:2376")
99+
100+
k.ServeHTTP(res, req)
101+
102+
assert.Equal(t, http.StatusNotFound, res.Code)
103+
assert.Equal(t, "Origin", res.Header().Get("Vary"))
104+
assert.Equal(t, "*", res.Header().Get("Access-Control-Allow-Origin"))
105+
assert.Empty(t, res.Header().Get("Access-Control-Allow-Methods"))
106+
}
107+
108+
func TestNewCorsWithConfig(t *testing.T) {
109+
cfg := CorsConfig{
110+
AllowedOrigins: []string{"http://localhost:2376"},
111+
AllowedMethods: []string{http.MethodGet, http.MethodPost},
112+
AllowedHeaders: []string{"Content-Type", "Accept"},
113+
ExposedHeaders: []string{"User-Agent"},
114+
AllowCredentials: true,
115+
AllowPrivateNetwork: true,
116+
MaxAge: 24 * time.Hour,
117+
}
118+
119+
k := kid.New()
120+
k.Use(NewCorsWithConfig(cfg))
121+
122+
res := httptest.NewRecorder()
123+
req := httptest.NewRequest(http.MethodOptions, "/test", nil)
124+
req.Header.Add("Access-Control-Request-Method", http.MethodPost)
125+
req.Header.Add("Origin", "http://localhost:2376")
126+
req.Header.Add("Access-Control-Request-Private-Network", "true")
127+
128+
k.ServeHTTP(res, req)
129+
130+
assert.Equal(t, http.StatusNoContent, res.Code)
131+
assert.Equal(t, "Origin", res.Header().Get("Vary"))
132+
assert.Equal(t, "http://localhost:2376", res.Header().Get("Access-Control-Allow-Origin"))
133+
assert.Equal(t, "GET, POST", res.Header().Get("Access-Control-Allow-Methods"))
134+
assert.Equal(t, "Content-Type, Accept", res.Header().Get("Access-Control-Allow-Headers"))
135+
assert.Equal(t, "User-Agent", res.Header().Get("Access-Control-Expose-Headers"))
136+
assert.Equal(t, "true", res.Header().Get("Access-Control-Allow-Credentials"))
137+
assert.Equal(t, "true", res.Header().Get("Access-Control-Allow-Private-Network"))
138+
assert.Equal(t, "86400", res.Header().Get("Access-Control-Max-Age"))
139+
140+
res = httptest.NewRecorder()
141+
req = httptest.NewRequest(http.MethodOptions, "/test", nil)
142+
req.Header.Add("Access-Control-Request-Method", http.MethodPost)
143+
144+
k.ServeHTTP(res, req)
145+
assert.Equal(t, http.StatusNotFound, res.Code)
146+
assert.Empty(t, res.Header().Get("Access-Control-Allow-Origin"))
147+
148+
res = httptest.NewRecorder()
149+
req = httptest.NewRequest(http.MethodOptions, "/test", nil)
150+
req.Header.Add("Access-Control-Request-Method", http.MethodPost)
151+
req.Header.Add("Origin", "http://localhost:4000")
152+
153+
k.ServeHTTP(res, req)
154+
155+
assert.Equal(t, http.StatusNotFound, res.Code)
156+
assert.Empty(t, res.Header().Get("Access-Control-Allow-Origin"))
157+
}

0 commit comments

Comments
 (0)