Skip to content

Commit a308846

Browse files
authored
fix: AccessControl middleware returns ErrMethodNotFound (#53)
1 parent 493fa69 commit a308846

File tree

3 files changed

+36
-36
lines changed

3 files changed

+36
-36
lines changed

common.go

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -61,14 +61,14 @@ type ProjectStore interface {
6161
type Config[T any] map[string]map[string]T
6262

6363
// Get returns the config value for the given request.
64-
func (c Config[T]) Get(_ context.Context, path string) (v T, err error) {
64+
func (c Config[T]) Get(_ context.Context, path string) (v T, ok bool) {
6565
if c == nil {
66-
return v, fmt.Errorf("config is nil")
66+
return v, false
6767
}
6868

6969
p := strings.Split(path, "/")
7070
if len(p) < 4 {
71-
return v, fmt.Errorf("path has not enough parts: %s", path)
71+
return v, false
7272
}
7373

7474
var (
@@ -78,15 +78,14 @@ func (c Config[T]) Get(_ context.Context, path string) (v T, err error) {
7878
)
7979

8080
if packageName != "rpc" {
81-
return v, fmt.Errorf("path doesn't include rpc: %s", path)
81+
return v, false
8282
}
8383

84-
v, ok := c[serviceName][methodName]
85-
if !ok {
86-
return v, fmt.Errorf("acl not defined for path: %s", path)
84+
if v, ok = c[serviceName][methodName]; !ok {
85+
return v, false
8786
}
8887

89-
return v, nil
88+
return v, true
9089
}
9190

9291
// Verify checks that the given config is valid for the given service.

middleware.go

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -303,23 +303,25 @@ func AccessControl(acl Config[ACL], cfg Options) func(next http.Handler) http.Ha
303303
return func(next http.Handler) http.Handler {
304304
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
305305
ctx := r.Context()
306-
acl, err := acl.Get(ctx, r.URL.Path)
307-
if err != nil {
308-
cfg.ErrHandler(r, w, proto.ErrUnauthorized.WithCausef("get acl: %w", err))
306+
307+
acl, ok := acl.Get(ctx, r.URL.Path)
308+
if !ok {
309+
// no ACL defined -> delegate to the next handler
310+
next.ServeHTTP(w, r)
309311
return
310312
}
311313

312-
if session, _ := GetSessionType(ctx); !acl.Includes(session) {
313-
err := proto.ErrPermissionDenied
314-
if session == proto.SessionType_Public {
315-
err = proto.ErrUnauthorized
316-
}
317-
318-
cfg.ErrHandler(r, w, err)
314+
session, _ := GetSessionType(ctx)
315+
if acl.Includes(session) {
316+
next.ServeHTTP(w, r)
319317
return
320318
}
321319

322-
next.ServeHTTP(w, r)
320+
err := proto.ErrUnauthorized
321+
if session > proto.SessionType_Public {
322+
err = proto.ErrPermissionDenied
323+
}
324+
cfg.ErrHandler(r, w, err)
323325
})
324326
}
325327
}

middleware_test.go

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -232,25 +232,25 @@ func TestInvalid(t *testing.T) {
232232
assert.True(t, ok)
233233
assert.NoError(t, err)
234234

235-
// Invalid request path with wrong not enough parts in path for valid RPC request
235+
// Invalid request path with wrong not enough parts in path for valid RPC request, this will delegate to next handler and return no error
236236
ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/%s/%s", ServiceName, MethodName), accessKey(AccessKey), jwt(authcontrol.S2SToken(JWTSecret, claims)))
237-
assert.False(t, ok)
238-
assert.ErrorIs(t, err, proto.ErrUnauthorized)
237+
assert.True(t, ok)
238+
assert.NoError(t, err)
239239

240-
// Invalid request path with wrong "rpc"
240+
// Invalid request path with wrong "rpc", this will delegate to next handler and return no error
241241
ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/pcr/%s/%s", ServiceName, MethodName), accessKey(AccessKey), jwt(authcontrol.S2SToken(JWTSecret, claims)))
242-
assert.False(t, ok)
243-
assert.ErrorIs(t, err, proto.ErrUnauthorized)
242+
assert.True(t, ok)
243+
assert.NoError(t, err)
244244

245-
// Invalid Service
245+
// Invalid Service, this will delegate to next handler and return no error
246246
ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceNameInvalid, MethodName), accessKey(AccessKey), jwt(authcontrol.S2SToken(JWTSecret, claims)))
247-
assert.False(t, ok)
248-
assert.ErrorIs(t, err, proto.ErrUnauthorized)
247+
assert.True(t, ok)
248+
assert.NoError(t, err)
249249

250-
// Invalid Method
250+
// Invalid Method, this will delegate to next handler and return no error
251251
ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodNameInvalid), accessKey(AccessKey), jwt(authcontrol.S2SToken(JWTSecret, claims)))
252-
assert.False(t, ok)
253-
assert.ErrorIs(t, err, proto.ErrUnauthorized)
252+
assert.True(t, ok)
253+
assert.NoError(t, err)
254254

255255
// Expired JWT Token
256256
claims["exp"] = time.Now().Add(-5 * time.Minute).Unix() // Note: Session() middleware allows some skew.
@@ -283,7 +283,7 @@ func TestCustomErrHandler(t *testing.T) {
283283

284284
ACLConfig := authcontrol.Config[authcontrol.ACL]{
285285
ServiceName: {
286-
MethodName: authcontrol.NewACL(proto.SessionType_Public.OrHigher()...),
286+
MethodName: authcontrol.NewACL(proto.SessionType_AccessKey.OrHigher()...),
287287
},
288288
}
289289

@@ -325,16 +325,15 @@ func TestCustomErrHandler(t *testing.T) {
325325

326326
r.Handle("/*", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
327327

328-
var claims map[string]any
329-
claims = map[string]any{"service": "client_service"}
328+
claims := map[string]any{"service": "client_service"}
330329

331330
// Valid Request
332331
ok, err := executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodName), accessKey(AccessKey), jwt(authcontrol.S2SToken(JWTSecret, claims)))
333332
assert.True(t, ok)
334333
assert.NoError(t, err)
335334

336-
// Invalid service which should return custom error from overrided ErrHandler
337-
ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceNameInvalid, MethodName), accessKey(AccessKey), jwt(authcontrol.S2SToken(JWTSecret, claims)))
335+
// Invalid Access, should return custom error
336+
ok, err = executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodName))
338337
assert.False(t, ok)
339338
assert.ErrorIs(t, err, customErr)
340339
}

0 commit comments

Comments
 (0)