diff --git a/examples/rbac/config.json b/examples/rbac/config.json new file mode 100644 index 000000000..9b3c35929 --- /dev/null +++ b/examples/rbac/config.json @@ -0,0 +1,11 @@ +{ + "route": { + "*": ["admin"], + "/post/*": ["admin","editor"], + "/dashboard": ["admin","editor"], + "/profile": ["admin","editor","user"], + "/home":["admin","editor","user"], + "/sayhello/*":["admin","editor","user"], + "/greet":["admin","editor","user"] + } + } \ No newline at end of file diff --git a/examples/rbac/main.go b/examples/rbac/main.go new file mode 100644 index 000000000..6b67f1f8f --- /dev/null +++ b/examples/rbac/main.go @@ -0,0 +1,42 @@ +package main + +import ( + "net/http" + + "gofr.dev/pkg/gofr" + "gofr.dev/pkg/gofr/rbac" +) + +func main() { + app := gofr.New() + + // loading the rbac config file which is required + rbacConfigs, err := rbac.LoadPermissions("config.json") + if err != nil { + return + } + + // example of setting override for a specific role + overrides := map[string]bool{"/greet": true} + rbacConfigs.OverRides = overrides + + // setting the role extractor function + rbacConfigs.RoleExtractorFunc = extractor + + // applying the middleware + app.UseMiddleware(rbac.Middleware(rbacConfigs)) + + // sample routes + app.GET("/sayhello/321", handler) + app.GET("/greet", rbac.RequireRole("user1", handler)) + + app.Run() // listens and serves on localhost:8000 +} + +func extractor(req *http.Request, _ ...any) (string, error) { + return req.Header.Get("X-USER-ROLE"), nil +} + +func handler(ctx *gofr.Context) (any, error) { + return "Hello World!", nil +} diff --git a/pkg/gofr/rbac/config.go b/pkg/gofr/rbac/config.go new file mode 100644 index 000000000..8ea962e28 --- /dev/null +++ b/pkg/gofr/rbac/config.go @@ -0,0 +1,28 @@ +package rbac + +import ( + "encoding/json" + "net/http" + "os" +) + +type Config struct { + RouteWithPermissions map[string][]string `json:"route"` // route: [Allowed roles] + RoleExtractorFunc func(req *http.Request, args ...any) (string, error) + OverRides map[string]bool // route: [override bool] +} + +func LoadPermissions(path string) (*Config, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, err + } + + var config Config + + if err := json.Unmarshal(data, &config); err != nil { + return nil, err + } + + return &config, nil +} diff --git a/pkg/gofr/rbac/config_test.go b/pkg/gofr/rbac/config_test.go new file mode 100644 index 000000000..19ed959ea --- /dev/null +++ b/pkg/gofr/rbac/config_test.go @@ -0,0 +1,47 @@ +package rbac + +import ( + "os" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestLoadPermissions_Success(t *testing.T) { + jsonContent := `{ + "route": {"admin":["read", "write"], "user":["read"]}, + "OverRides": {"admin":true, "user":false} + }` + tempFile, err := os.CreateTemp("", "test_permissions_*.json") + assert.NoError(t, err) + defer os.Remove(tempFile.Name()) + + _, err = tempFile.Write([]byte(jsonContent)) + assert.NoError(t, err) + tempFile.Close() + + cfg, err := LoadPermissions(tempFile.Name()) + assert.NoError(t, err) + assert.Equal(t, map[string][]string{"admin": {"read", "write"}, "user": {"read"}}, cfg.RouteWithPermissions) + assert.Equal(t, map[string]bool{"admin": true, "user": false}, cfg.OverRides) +} + +func TestLoadPermissions_FileNotFound(t *testing.T) { + cfg, err := LoadPermissions("non_existent_file.json") + assert.Nil(t, cfg) + assert.Error(t, err) +} + +func TestLoadPermissions_InvalidJSON(t *testing.T) { + tempFile, err := os.CreateTemp("", "badjson_*.json") + assert.NoError(t, err) + defer os.Remove(tempFile.Name()) + + _, err = tempFile.Write([]byte(`{"route": [INVALID JSON}`)) + assert.NoError(t, err) + tempFile.Close() + + cfg, err := LoadPermissions(tempFile.Name()) + assert.Nil(t, cfg) + assert.Error(t, err) +} diff --git a/pkg/gofr/rbac/helper.go b/pkg/gofr/rbac/helper.go new file mode 100644 index 000000000..5ad1911d6 --- /dev/null +++ b/pkg/gofr/rbac/helper.go @@ -0,0 +1,13 @@ +package rbac + +import "gofr.dev/pkg/gofr" + +func HasRole(ctx *gofr.Context, role string) bool { + expRole, _ := ctx.Context.Value(userRole).(string) + return expRole == role +} + +func GetUserRole(ctx *gofr.Context) string { + role, _ := ctx.Context.Value(userRole).(string) + return role +} diff --git a/pkg/gofr/rbac/helper_test.go b/pkg/gofr/rbac/helper_test.go new file mode 100644 index 000000000..5c0ced9fd --- /dev/null +++ b/pkg/gofr/rbac/helper_test.go @@ -0,0 +1,52 @@ +package rbac + +import ( + "context" + "testing" + + "gofr.dev/pkg/gofr" +) + +func TestHasRole(t *testing.T) { + tests := []struct { + name string + ctxRoleVal string + checkRole string + expectedRes bool + }{ + {"matching role", "admin", "admin", true}, + {"non-matching role", "viewer", "admin", false}, + {"empty role in context", "", "admin", false}, + {"nil role in context", "", "", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create base context with the userRole value + baseCtx := context.WithValue(t.Context(), userRole, tt.ctxRoleVal) + + // Wrap baseCtx in gofr.Context + gofrCtx := &gofr.Context{Context: baseCtx} + + got := HasRole(gofrCtx, tt.checkRole) + if got != tt.expectedRes { + t.Errorf("HasRole() = %v, want %v", got, tt.expectedRes) + } + }) + } +} +func TestGetUserRole(t *testing.T) { + expectedRole := "editor" + baseCtx := context.WithValue(t.Context(), userRole, expectedRole) + gofrCtx := &gofr.Context{Context: baseCtx} + + if role := GetUserRole(gofrCtx); role != expectedRole { + t.Errorf("GetUserRole() = %v, want %v", role, expectedRole) + } + + // Test no role set should return "" + emptyCtx := &gofr.Context{Context: t.Context()} + if role := GetUserRole(emptyCtx); role != "" { + t.Errorf("GetUserRole() with no role = %v, want empty string", role) + } +} diff --git a/pkg/gofr/rbac/match.go b/pkg/gofr/rbac/match.go new file mode 100644 index 000000000..1b36accf5 --- /dev/null +++ b/pkg/gofr/rbac/match.go @@ -0,0 +1,33 @@ +package rbac + +import ( + "path" +) + +func isRoleAllowed(role, apiroute string, config *Config) bool { + var routePermissions []string + + // find the matched route from config + for route, allowedRoles := range config.RouteWithPermissions { + if isMatched, _ := path.Match(route, apiroute); isMatched && route != "" { + // check if override is set for the matched route + if config.OverRides[apiroute] { + return true + } + routePermissions = allowedRoles + break + } + } + + // append global permissions if any + routePermissions = append(routePermissions, config.RouteWithPermissions["*"]...) + + // check if role is in allowed roles for the matched route + for _, allowedRole := range routePermissions { + if allowedRole == role || allowedRole == "*" { + return true + } + } + + return false +} diff --git a/pkg/gofr/rbac/match_test.go b/pkg/gofr/rbac/match_test.go new file mode 100644 index 000000000..5e1196af5 --- /dev/null +++ b/pkg/gofr/rbac/match_test.go @@ -0,0 +1,42 @@ +package rbac + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestIsRoleAllowed(t *testing.T) { + config := &Config{ + RouteWithPermissions: map[string][]string{ + "/admin/*": {"admin"}, + "/user/*": {"user", "admin"}, + "*": {"guest"}, + }, + OverRides: map[string]bool{ + "/admin/home": true, + }, + } + + tests := []struct { + name string + role string + route string + expected bool + }{ + {"Override true", "anyone", "/admin/home", true}, + {"Pattern match /admin/*", "admin", "/admin/dashboard", true}, + {"Pattern match negative", "user", "/admin/dashboard", false}, + {"Non-pattern route", "user", "/user/profile", true}, + {"Wildcard permission", "guest", "/anything", true}, + {"No route or global match", "unknown", "/private", false}, + {"Not matched or globally allowed", "nobody", "/wildcard", false}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := isRoleAllowed(tc.role, tc.route, config) + assert.Equal(t, tc.expected, got, tc.name) + }) + } +} diff --git a/pkg/gofr/rbac/middleware.go b/pkg/gofr/rbac/middleware.go new file mode 100644 index 000000000..0a0bb4a0e --- /dev/null +++ b/pkg/gofr/rbac/middleware.go @@ -0,0 +1,50 @@ +package rbac + +import ( + "context" + "errors" + "net/http" + + "gofr.dev/pkg/gofr" +) + +type authMethod int + +const userRole authMethod = 4 + +var ErrAccessDenied = errors.New("forbidden: access denied") + +func Middleware(config *Config, args ...any) func(handler http.Handler) http.Handler { + return func(handler http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + role, err := config.RoleExtractorFunc(r, args) + if err != nil { + http.Error(w, "Unauthorized: Missing or invalid role", http.StatusUnauthorized) + + return + } + + if !isRoleAllowed(role, r.URL.Path, config) { + http.Error(w, "Forbidden: Access denied", http.StatusForbidden) + + return + } + + ctx := context.WithValue(r.Context(), userRole, role) + + handler.ServeHTTP(w, r.WithContext(ctx)) + }) + } +} + +func RequireRole(allowedRole string, handlerFunc gofr.Handler) gofr.Handler { + return func(ctx *gofr.Context) (any, error) { + role, _ := ctx.Context.Value(userRole).(string) + + if role == allowedRole { + return handlerFunc(ctx) + } + + return nil, ErrAccessDenied + } +} diff --git a/pkg/gofr/rbac/middleware_test.go b/pkg/gofr/rbac/middleware_test.go new file mode 100644 index 000000000..adbedab5e --- /dev/null +++ b/pkg/gofr/rbac/middleware_test.go @@ -0,0 +1,144 @@ +package rbac + +import ( + "context" + "errors" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "gofr.dev/pkg/gofr" +) + +// mock role extractor function for testing +func mockRoleExtractor(r *http.Request, args ...any) (string, error) { + role := r.Header.Get("Role") + if role == "" { + return "", errors.New("no role") + } + return role, nil +} + +func TestMiddleware_Authorization(t *testing.T) { + config := &Config{ + RouteWithPermissions: map[string][]string{ + "/allowed": {"admin"}, + }, + OverRides: map[string]bool{}, + RoleExtractorFunc: mockRoleExtractor, + } + + // next handler to confirm request passed through middleware + nextCalled := false + nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + nextCalled = true + w.WriteHeader(http.StatusOK) + }) + + middleware := Middleware(config) + + // test cases + tests := []struct { + name string + roleHeader string + requestPath string + wantStatus int + wantNextCall bool + }{ + { + name: "No role header", + roleHeader: "", + requestPath: "/allowed", + wantStatus: http.StatusUnauthorized, + wantNextCall: false, + }, + { + name: "Unauthorized role", + roleHeader: "user", + requestPath: "/allowed", + wantStatus: http.StatusForbidden, + wantNextCall: false, + }, + { + name: "Authorized role", + roleHeader: "admin", + requestPath: "/allowed", + wantStatus: http.StatusOK, + wantNextCall: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + nextCalled = false + req := httptest.NewRequest(http.MethodGet, tc.requestPath, nil) + if tc.roleHeader != "" { + req.Header.Set("Role", tc.roleHeader) + } + w := httptest.NewRecorder() + + handlerToTest := middleware(nextHandler) + handlerToTest.ServeHTTP(w, req) + + assert.Equal(t, tc.wantStatus, w.Code) + assert.Equal(t, tc.wantNextCall, nextCalled) + }) + } +} + +func TestRequireRole_Handler(t *testing.T) { + allowedRole := "admin" + called := false + handlerFunc := func(ctx *gofr.Context) (any, error) { + called = true + return "success", nil + } + + wrappedHandler := RequireRole(allowedRole, handlerFunc) + + tests := []struct { + name string + contextRole string + wantErr error + wantCalled bool + }{ + { + name: "Role allowed", + contextRole: "admin", + wantErr: nil, + wantCalled: true, + }, + { + name: "Role denied", + contextRole: "user", + wantErr: ErrAccessDenied, + wantCalled: false, + }, + { + name: "No role in context", + contextRole: "", + wantErr: ErrAccessDenied, + wantCalled: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + called = false + ctx := &gofr.Context{ + Context: context.WithValue(context.Background(), userRole, tc.contextRole), + } + resp, err := wrappedHandler(ctx) + + assert.Equal(t, tc.wantErr, err) + if tc.wantCalled { + assert.True(t, called) + assert.Equal(t, "success", resp) + } else { + assert.False(t, called) + assert.Nil(t, resp) + } + }) + } +}