Skip to content

Commit 186cf2b

Browse files
committed
casbin: add EnforceHandler to allow custom callback to handle enforcing.
1 parent 4d116ee commit 186cf2b

File tree

2 files changed

+66
-30
lines changed

2 files changed

+66
-30
lines changed

casbin/casbin.go

Lines changed: 40 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -45,11 +45,11 @@ Advanced example:
4545
package casbin
4646

4747
import (
48-
"net/http"
49-
48+
"errors"
5049
"github.com/casbin/casbin/v2"
5150
"github.com/labstack/echo/v4"
5251
"github.com/labstack/echo/v4/middleware"
52+
"net/http"
5353
)
5454

5555
type (
@@ -59,11 +59,18 @@ type (
5959
Skipper middleware.Skipper
6060

6161
// Enforcer CasbinAuth main rule.
62-
// Required.
62+
// One of Enforcer or EnforceHandler fields is required.
6363
Enforcer *casbin.Enforcer
6464

65+
// EnforceHandler is custom callback to handle enforcing.
66+
// One of Enforcer or EnforceHandler fields is required.
67+
EnforceHandler func(c echo.Context, user string) (bool, error)
68+
6569
// Method to get the username - defaults to using basic auth
6670
UserGetter func(c echo.Context) (string, error)
71+
72+
// Method to handle errors
73+
ErrorHandler func(c echo.Context, internal error, proposedStatus int) error
6774
}
6875
)
6976

@@ -75,6 +82,11 @@ var (
7582
username, _, _ := c.Request().BasicAuth()
7683
return username, nil
7784
},
85+
ErrorHandler: func(c echo.Context, internal error, proposedStatus int) error {
86+
err := echo.NewHTTPError(proposedStatus, internal.Error())
87+
err.Internal = internal
88+
return err
89+
},
7890
}
7991
)
8092

@@ -91,44 +103,42 @@ func Middleware(ce *casbin.Enforcer) echo.MiddlewareFunc {
91103
// MiddlewareWithConfig returns a CasbinAuth middleware with config.
92104
// See `Middleware()`.
93105
func MiddlewareWithConfig(config Config) echo.MiddlewareFunc {
94-
// Defaults
106+
if config.Enforcer == nil && config.EnforceHandler == nil {
107+
panic("one of casbin middleware Enforcer or EnforceHandler fields must be set")
108+
}
95109
if config.Skipper == nil {
96110
config.Skipper = DefaultConfig.Skipper
97111
}
112+
if config.UserGetter == nil {
113+
config.UserGetter = DefaultConfig.UserGetter
114+
}
115+
if config.ErrorHandler == nil {
116+
config.ErrorHandler = DefaultConfig.ErrorHandler
117+
}
118+
if config.EnforceHandler == nil {
119+
config.EnforceHandler = func(c echo.Context, user string) (bool, error) {
120+
return config.Enforcer.Enforce(user, c.Request().URL.Path, c.Request().Method)
121+
}
122+
}
98123

99124
return func(next echo.HandlerFunc) echo.HandlerFunc {
100125
return func(c echo.Context) error {
101126
if config.Skipper(c) {
102127
return next(c)
103128
}
104129

105-
if pass, err := config.CheckPermission(c); err == nil && pass {
106-
return next(c)
107-
} else if err != nil {
108-
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
130+
user, err := config.UserGetter(c)
131+
if err != nil {
132+
return config.ErrorHandler(c, err, http.StatusForbidden)
109133
}
110-
111-
return echo.ErrForbidden
134+
pass, err := config.EnforceHandler(c, user)
135+
if err != nil {
136+
return config.ErrorHandler(c, err, http.StatusInternalServerError)
137+
}
138+
if !pass {
139+
return config.ErrorHandler(c, errors.New("enforce did not pass"), http.StatusForbidden)
140+
}
141+
return next(c)
112142
}
113143
}
114144
}
115-
116-
// GetUserName gets the user name from the request.
117-
// It calls the UserGetter field of the Config struct that allows the caller to customize user identification.
118-
func (a *Config) GetUserName(c echo.Context) (string, error) {
119-
username, err := a.UserGetter(c)
120-
return username, err
121-
}
122-
123-
// CheckPermission checks the user/method/path combination from the request.
124-
// Returns true (permission granted) or false (permission forbidden)
125-
func (a *Config) CheckPermission(c echo.Context) (bool, error) {
126-
user, err := a.GetUserName(c)
127-
if err != nil {
128-
// Fail safe and do not propagate
129-
return false, nil
130-
}
131-
method := c.Request().Method
132-
path := c.Request().URL.Path
133-
return a.Enforcer.Enforce(user, path, method)
134-
}

casbin/casbin_test.go

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@ package casbin
22

33
import (
44
"errors"
5+
"github.com/stretchr/testify/assert"
56
"net/http"
67
"net/http/httptest"
8+
"strings"
79
"testing"
810

911
"github.com/casbin/casbin/v2"
@@ -131,3 +133,27 @@ func TestUserGetterError(t *testing.T) {
131133
})
132134
testRequest(t, h, "cathy", "/dataset1/item", "GET", 403)
133135
}
136+
137+
func TestCustomEnforceHandler(t *testing.T) {
138+
ce, err := casbin.NewEnforcer("auth_model.conf", "auth_policy.csv")
139+
assert.NoError(t, err)
140+
141+
_, err = ce.AddPolicy("bob", "/user/bob", "PATCH_SELF")
142+
assert.NoError(t, err)
143+
144+
cnf := Config{
145+
EnforceHandler: func(c echo.Context, user string) (bool, error) {
146+
method := c.Request().Method
147+
if strings.HasPrefix(c.Request().URL.Path, "/user/bob") {
148+
method += "_SELF"
149+
}
150+
return ce.Enforce(user, c.Request().URL.Path, method)
151+
},
152+
}
153+
h := MiddlewareWithConfig(cnf)(func(c echo.Context) error {
154+
return c.String(http.StatusOK, "test")
155+
})
156+
testRequest(t, h, "bob", "/dataset2/resource1", "GET", http.StatusOK)
157+
testRequest(t, h, "bob", "/user/alice", "PATCH", http.StatusForbidden)
158+
testRequest(t, h, "bob", "/user/bob", "PATCH", http.StatusOK)
159+
}

0 commit comments

Comments
 (0)