Skip to content

Commit 8f7582e

Browse files
authored
fix: add valid origin check for cors (#83)
Resolves #72
1 parent bdbbe4a commit 8f7582e

File tree

12 files changed

+180
-59
lines changed

12 files changed

+180
-59
lines changed

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,4 @@ cmd:
66
clean:
77
rm -rf build
88
test:
9-
cd server && go clean --testcache && go test ./...
9+
cd server && go clean --testcache && go test -v ./...

server/env/env.go

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,15 +87,30 @@ func InitEnv() {
8787

8888
allowedOriginsSplit := strings.Split(os.Getenv("ALLOWED_ORIGINS"), ",")
8989
allowedOrigins := []string{}
90+
hasWildCard := false
91+
9092
for _, val := range allowedOriginsSplit {
9193
trimVal := strings.TrimSpace(val)
9294
if trimVal != "" {
93-
allowedOrigins = append(allowedOrigins, trimVal)
95+
if trimVal != "*" {
96+
host, port := utils.GetHostParts(trimVal)
97+
allowedOrigins = append(allowedOrigins, host+":"+port)
98+
} else {
99+
hasWildCard = true
100+
allowedOrigins = append(allowedOrigins, trimVal)
101+
break
102+
}
94103
}
95104
}
105+
106+
if len(allowedOrigins) > 1 && hasWildCard {
107+
allowedOrigins = []string{"*"}
108+
}
109+
96110
if len(allowedOrigins) == 0 {
97111
allowedOrigins = []string{"*"}
98112
}
113+
99114
constants.ALLOWED_ORIGINS = allowedOrigins
100115

101116
if *ARG_AUTHORIZER_URL != "" {

server/handlers/app.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ func AppHandler() gin.HandlerFunc {
4949
stateObj.RedirectURL = strings.TrimSuffix(stateObj.RedirectURL, "/")
5050

5151
// validate redirect url with allowed origins
52-
if !utils.IsValidRedirectURL(stateObj.RedirectURL) {
52+
if !utils.IsValidOrigin(stateObj.RedirectURL) {
5353
c.JSON(400, gin.H{"error": "invalid redirect url"})
5454
return
5555
}
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
package integration_test
2+
3+
import (
4+
"net/http"
5+
"net/http/httptest"
6+
"testing"
7+
8+
"github.com/authorizerdev/authorizer/server/constants"
9+
"github.com/authorizerdev/authorizer/server/env"
10+
"github.com/authorizerdev/authorizer/server/middlewares"
11+
"github.com/gin-contrib/location"
12+
"github.com/gin-gonic/gin"
13+
"github.com/stretchr/testify/assert"
14+
)
15+
16+
func TestCors(t *testing.T) {
17+
constants.ENV_PATH = "../../.env.local"
18+
env.InitEnv()
19+
r := gin.Default()
20+
r.Use(location.Default())
21+
r.Use(middlewares.GinContextToContextMiddleware())
22+
r.Use(middlewares.CORSMiddleware())
23+
allowedOrigin := "http://localhost:8080" // The allowed origin that you want to check
24+
notAllowedOrigin := "http://myapp.com"
25+
26+
server := httptest.NewServer(r)
27+
defer server.Close()
28+
29+
client := &http.Client{}
30+
req, _ := http.NewRequest(
31+
"GET",
32+
"http://"+server.Listener.Addr().String()+"/api",
33+
nil,
34+
)
35+
req.Header.Add("Origin", allowedOrigin)
36+
37+
get, _ := client.Do(req)
38+
39+
// You should get your origin (or a * depending on your config) if the
40+
// passed origin is allowed.
41+
o := get.Header.Get("Access-Control-Allow-Origin")
42+
assert.NotEqual(t, o, notAllowedOrigin, "Origins should not match")
43+
assert.Equal(t, o, allowedOrigin, "Origins don't match")
44+
}

server/main.go

Lines changed: 3 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,53 +1,17 @@
11
package main
22

33
import (
4-
"context"
5-
"log"
6-
7-
"github.com/authorizerdev/authorizer/server/constants"
84
"github.com/authorizerdev/authorizer/server/db"
95
"github.com/authorizerdev/authorizer/server/env"
106
"github.com/authorizerdev/authorizer/server/handlers"
7+
"github.com/authorizerdev/authorizer/server/middlewares"
118
"github.com/authorizerdev/authorizer/server/oauth"
129
"github.com/authorizerdev/authorizer/server/session"
1310
"github.com/authorizerdev/authorizer/server/utils"
1411
"github.com/gin-contrib/location"
1512
"github.com/gin-gonic/gin"
1613
)
1714

18-
func GinContextToContextMiddleware() gin.HandlerFunc {
19-
return func(c *gin.Context) {
20-
if constants.AUTHORIZER_URL == "" {
21-
url := location.Get(c)
22-
constants.AUTHORIZER_URL = url.Scheme + "://" + c.Request.Host
23-
log.Println("=> authorizer url:", constants.AUTHORIZER_URL)
24-
}
25-
ctx := context.WithValue(c.Request.Context(), "GinContextKey", c)
26-
c.Request = c.Request.WithContext(ctx)
27-
c.Next()
28-
}
29-
}
30-
31-
// TODO use allowed origins for cors origin
32-
// TODO throw error if url is not allowed
33-
func CORSMiddleware() gin.HandlerFunc {
34-
return func(c *gin.Context) {
35-
origin := c.Request.Header.Get("Origin")
36-
constants.APP_URL = origin
37-
c.Writer.Header().Set("Access-Control-Allow-Origin", origin)
38-
c.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
39-
c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With")
40-
c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT")
41-
42-
if c.Request.Method == "OPTIONS" {
43-
c.AbortWithStatus(204)
44-
return
45-
}
46-
47-
c.Next()
48-
}
49-
}
50-
5115
func main() {
5216
env.InitEnv()
5317
db.InitDB()
@@ -57,8 +21,8 @@ func main() {
5721

5822
r := gin.Default()
5923
r.Use(location.Default())
60-
r.Use(GinContextToContextMiddleware())
61-
r.Use(CORSMiddleware())
24+
r.Use(middlewares.GinContextToContextMiddleware())
25+
r.Use(middlewares.CORSMiddleware())
6226

6327
r.GET("/", handlers.PlaygroundHandler())
6428
r.POST("/graphql", handlers.GraphqlHandler())

server/middlewares/context.go

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
package middlewares
2+
3+
import (
4+
"context"
5+
"log"
6+
7+
"github.com/authorizerdev/authorizer/server/constants"
8+
"github.com/gin-contrib/location"
9+
"github.com/gin-gonic/gin"
10+
)
11+
12+
func GinContextToContextMiddleware() gin.HandlerFunc {
13+
return func(c *gin.Context) {
14+
if constants.AUTHORIZER_URL == "" {
15+
url := location.Get(c)
16+
constants.AUTHORIZER_URL = url.Scheme + "://" + c.Request.Host
17+
log.Println("=> authorizer url:", constants.AUTHORIZER_URL)
18+
}
19+
ctx := context.WithValue(c.Request.Context(), "GinContextKey", c)
20+
c.Request = c.Request.WithContext(ctx)
21+
c.Next()
22+
}
23+
}

server/middlewares/cors.go

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
package middlewares
2+
3+
import (
4+
"github.com/authorizerdev/authorizer/server/constants"
5+
"github.com/authorizerdev/authorizer/server/utils"
6+
"github.com/gin-gonic/gin"
7+
)
8+
9+
func CORSMiddleware() gin.HandlerFunc {
10+
return func(c *gin.Context) {
11+
origin := c.Request.Header.Get("Origin")
12+
constants.APP_URL = origin
13+
14+
if utils.IsValidOrigin(origin) {
15+
c.Writer.Header().Set("Access-Control-Allow-Origin", origin)
16+
}
17+
18+
c.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
19+
c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With")
20+
c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT")
21+
22+
if c.Request.Method == "OPTIONS" {
23+
c.AbortWithStatus(204)
24+
return
25+
}
26+
27+
c.Next()
28+
}
29+
}

server/utils/cookie.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ import (
1010
func SetCookie(gc *gin.Context, token string) {
1111
secure := true
1212
httpOnly := true
13-
host := GetHostName(constants.AUTHORIZER_URL)
13+
host, _ := GetHostParts(constants.AUTHORIZER_URL)
1414
domain := GetDomainName(constants.AUTHORIZER_URL)
1515
if domain != "localhost" {
1616
domain = "." + domain
@@ -37,7 +37,7 @@ func DeleteCookie(gc *gin.Context) {
3737
secure := true
3838
httpOnly := true
3939

40-
host := GetDomainName(constants.AUTHORIZER_URL)
40+
host, _ := GetHostParts(constants.AUTHORIZER_URL)
4141
domain := GetDomainName(constants.AUTHORIZER_URL)
4242
if domain != "localhost" {
4343
domain = "." + domain

server/utils/urls.go

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,21 +5,32 @@ import (
55
"strings"
66
)
77

8-
// GetHostName function to get hostname
9-
func GetHostName(auth_url string) string {
10-
u, err := url.Parse(auth_url)
8+
// GetHostName function returns hostname and port
9+
func GetHostParts(uri string) (string, string) {
10+
tempURI := uri
11+
if !strings.HasPrefix(tempURI, "http") && strings.HasPrefix(tempURI, "https") {
12+
tempURI = "https://" + tempURI
13+
}
14+
15+
u, err := url.Parse(tempURI)
1116
if err != nil {
12-
return `localhost`
17+
return "localhost", "8080"
1318
}
1419

1520
host := u.Hostname()
21+
port := u.Port()
1622

17-
return host
23+
return host, port
1824
}
1925

2026
// GetDomainName function to get domain name
21-
func GetDomainName(auth_url string) string {
22-
u, err := url.Parse(auth_url)
27+
func GetDomainName(uri string) string {
28+
tempURI := uri
29+
if !strings.HasPrefix(tempURI, "http") && strings.HasPrefix(tempURI, "https") {
30+
tempURI = "https://" + tempURI
31+
}
32+
33+
u, err := url.Parse(tempURI)
2334
if err != nil {
2435
return `localhost`
2536
}

server/utils/urls_test.go

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,13 @@ import (
77
)
88

99
func TestGetHostName(t *testing.T) {
10-
authorizer_url := "http://test.herokuapp.com"
10+
authorizer_url := "http://test.herokuapp.com:80"
1111

12-
got := GetHostName(authorizer_url)
13-
want := "test.herokuapp.com"
12+
host, port := GetHostParts(authorizer_url)
13+
expectedHost := "test.herokuapp.com"
1414

15-
assert.Equal(t, got, want, "hostname should be equal")
15+
assert.Equal(t, host, expectedHost, "hostname should be equal")
16+
assert.Equal(t, port, "80", "port should be 80")
1617
}
1718

1819
func TestGetDomainName(t *testing.T) {

0 commit comments

Comments
 (0)