@@ -45,11 +45,11 @@ Advanced example:
45
45
package casbin
46
46
47
47
import (
48
- "net/http"
49
-
48
+ "errors"
50
49
"github.com/casbin/casbin/v2"
51
50
"github.com/labstack/echo/v4"
52
51
"github.com/labstack/echo/v4/middleware"
52
+ "net/http"
53
53
)
54
54
55
55
type (
@@ -59,11 +59,18 @@ type (
59
59
Skipper middleware.Skipper
60
60
61
61
// Enforcer CasbinAuth main rule.
62
- // Required .
62
+ // One of Enforcer or EnforceHandler fields is required .
63
63
Enforcer * casbin.Enforcer
64
64
65
+ // EnforceHandler is custom callback to handle enforcing.
66
+ // One of Enforcer or EnforceHandler fields is required.
67
+ EnforceHandler func (c echo.Context , user string ) (bool , error )
68
+
65
69
// Method to get the username - defaults to using basic auth
66
70
UserGetter func (c echo.Context ) (string , error )
71
+
72
+ // Method to handle errors
73
+ ErrorHandler func (c echo.Context , internal error , proposedStatus int ) error
67
74
}
68
75
)
69
76
75
82
username , _ , _ := c .Request ().BasicAuth ()
76
83
return username , nil
77
84
},
85
+ ErrorHandler : func (c echo.Context , internal error , proposedStatus int ) error {
86
+ err := echo .NewHTTPError (proposedStatus , internal .Error ())
87
+ err .Internal = internal
88
+ return err
89
+ },
78
90
}
79
91
)
80
92
@@ -91,44 +103,42 @@ func Middleware(ce *casbin.Enforcer) echo.MiddlewareFunc {
91
103
// MiddlewareWithConfig returns a CasbinAuth middleware with config.
92
104
// See `Middleware()`.
93
105
func MiddlewareWithConfig (config Config ) echo.MiddlewareFunc {
94
- // Defaults
106
+ if config .Enforcer == nil && config .EnforceHandler == nil {
107
+ panic ("one of casbin middleware Enforcer or EnforceHandler fields must be set" )
108
+ }
95
109
if config .Skipper == nil {
96
110
config .Skipper = DefaultConfig .Skipper
97
111
}
112
+ if config .UserGetter == nil {
113
+ config .UserGetter = DefaultConfig .UserGetter
114
+ }
115
+ if config .ErrorHandler == nil {
116
+ config .ErrorHandler = DefaultConfig .ErrorHandler
117
+ }
118
+ if config .EnforceHandler == nil {
119
+ config .EnforceHandler = func (c echo.Context , user string ) (bool , error ) {
120
+ return config .Enforcer .Enforce (user , c .Request ().URL .Path , c .Request ().Method )
121
+ }
122
+ }
98
123
99
124
return func (next echo.HandlerFunc ) echo.HandlerFunc {
100
125
return func (c echo.Context ) error {
101
126
if config .Skipper (c ) {
102
127
return next (c )
103
128
}
104
129
105
- if pass , err := config .CheckPermission (c ); err == nil && pass {
106
- return next (c )
107
- } else if err != nil {
108
- return echo .NewHTTPError (http .StatusInternalServerError , err .Error ())
130
+ user , err := config .UserGetter (c )
131
+ if err != nil {
132
+ return config .ErrorHandler (c , err , http .StatusForbidden )
109
133
}
110
-
111
- return echo .ErrForbidden
134
+ pass , err := config .EnforceHandler (c , user )
135
+ if err != nil {
136
+ return config .ErrorHandler (c , err , http .StatusInternalServerError )
137
+ }
138
+ if ! pass {
139
+ return config .ErrorHandler (c , errors .New ("enforce did not pass" ), http .StatusForbidden )
140
+ }
141
+ return next (c )
112
142
}
113
143
}
114
144
}
115
-
116
- // GetUserName gets the user name from the request.
117
- // It calls the UserGetter field of the Config struct that allows the caller to customize user identification.
118
- func (a * Config ) GetUserName (c echo.Context ) (string , error ) {
119
- username , err := a .UserGetter (c )
120
- return username , err
121
- }
122
-
123
- // CheckPermission checks the user/method/path combination from the request.
124
- // Returns true (permission granted) or false (permission forbidden)
125
- func (a * Config ) CheckPermission (c echo.Context ) (bool , error ) {
126
- user , err := a .GetUserName (c )
127
- if err != nil {
128
- // Fail safe and do not propagate
129
- return false , nil
130
- }
131
- method := c .Request ().Method
132
- path := c .Request ().URL .Path
133
- return a .Enforcer .Enforce (user , path , method )
134
- }
0 commit comments