-
Notifications
You must be signed in to change notification settings - Fork 0
Check and remove grants and role memberships before removing a role #39
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
5ec8362
7d39cc9
2659003
63a79a9
104813d
9990ace
657a821
a16d446
d0bf546
7c016d2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -14,6 +14,9 @@ import ( | |
| "go.uber.org/zap" | ||
| ) | ||
|
|
||
| var errRevokeGrantsFromRole = errors.New("error revoking grants from role") | ||
| var errRevokeParentRolesFromRole = errors.New("error revoking parent roles from role") | ||
|
|
||
| type RoleModel struct { | ||
| ID int64 `db:"oid"` | ||
| Name string `db:"rolname"` | ||
|
|
@@ -157,19 +160,268 @@ func (c *Client) CreateRole(ctx context.Context, roleName string) error { | |
| return err | ||
| } | ||
|
|
||
| func (c *Client) DeleteRole(ctx context.Context, roleName string) error { | ||
| // RoleOwnsObjects checks if a role owns any database objects. | ||
| func (c *Client) RoleOwnsObjects(ctx context.Context, roleName string) (bool, error) { | ||
| l := ctxzap.Extract(ctx) | ||
|
|
||
| query := ` | ||
| SELECT EXISTS( | ||
| SELECT 1 FROM ( | ||
| -- Check for owned schemas | ||
| SELECT 1 FROM pg_namespace WHERE nspowner = (SELECT oid FROM pg_roles WHERE rolname = $1) | ||
| UNION ALL | ||
| -- Check for owned tables | ||
| SELECT 1 FROM pg_class WHERE relowner = (SELECT oid FROM pg_roles WHERE rolname = $1) | ||
| UNION ALL | ||
| -- Check for owned functions | ||
| SELECT 1 FROM pg_proc WHERE proowner = (SELECT oid FROM pg_roles WHERE rolname = $1) | ||
| UNION ALL | ||
| -- Check for owned sequences | ||
| SELECT 1 FROM pg_class WHERE relowner = (SELECT oid FROM pg_roles WHERE rolname = $1) AND relkind = 'S' | ||
| UNION ALL | ||
| -- Check for owned views | ||
| SELECT 1 FROM pg_class WHERE relowner = (SELECT oid FROM pg_roles WHERE rolname = $1) AND relkind = 'v' | ||
| UNION ALL | ||
| -- Check for owned types | ||
| SELECT 1 FROM pg_type WHERE typowner = (SELECT oid FROM pg_roles WHERE rolname = $1) | ||
| UNION ALL | ||
| -- Check for owned databases | ||
| SELECT 1 FROM pg_database WHERE datdba = (SELECT oid FROM pg_roles WHERE rolname = $1) | ||
| ) owned_objects | ||
| )` | ||
|
|
||
| var ownsObjects bool | ||
| err := c.db.QueryRow(ctx, query, roleName).Scan(&ownsObjects) | ||
| if err != nil { | ||
| l.Error("error checking if role owns objects", zap.Error(err)) | ||
| return false, err | ||
| } | ||
|
|
||
| return ownsObjects, nil | ||
| } | ||
|
|
||
| // RevokeAllGrantsFromRole revokes all grants from a role across all schemas. | ||
| func (c *Client) RevokeAllGrantsFromRole(ctx context.Context, roleName string) error { | ||
| l := ctxzap.Extract(ctx) | ||
|
|
||
| sanitizedRoleName := pgx.Identifier{roleName}.Sanitize() | ||
|
|
||
| schemasQuery := ` | ||
| SELECT nspname | ||
| FROM pg_namespace | ||
| WHERE nspname NOT LIKE 'pg_%' | ||
| AND nspname != 'information_schema' | ||
| ORDER BY nspname` | ||
|
|
||
| rows, err := c.db.Query(ctx, schemasQuery) | ||
| if err != nil { | ||
| l.Error("error querying schemas", zap.Error(err)) | ||
| return err | ||
| } | ||
| defer rows.Close() | ||
|
|
||
| var schemas []string | ||
| for rows.Next() { | ||
| var schemaName string | ||
| if err := rows.Scan(&schemaName); err != nil { | ||
| l.Error("error scanning schema name", zap.Error(err)) | ||
| return err | ||
| } | ||
| schemas = append(schemas, schemaName) | ||
| } | ||
|
|
||
| if err := rows.Err(); err != nil { | ||
| l.Error("error iterating schemas", zap.Error(err)) | ||
| return err | ||
| } | ||
|
|
||
| var revokeError error | ||
| for _, schema := range schemas { | ||
| sanitizedSchema := pgx.Identifier{schema}.Sanitize() | ||
|
|
||
| revokeTablesQuery := fmt.Sprintf("REVOKE ALL ON ALL TABLES IN SCHEMA %s FROM %s", sanitizedSchema, sanitizedRoleName) | ||
| l.Debug("revoking table grants", zap.String("query", revokeTablesQuery)) | ||
| if _, err := c.db.Exec(ctx, revokeTablesQuery); err != nil { | ||
| l.Warn("error revoking table grants", zap.String("schema", schema), zap.Error(err)) | ||
| revokeError = errors.Join(revokeError, err) | ||
| } | ||
|
|
||
| revokeSequencesQuery := fmt.Sprintf("REVOKE ALL ON ALL SEQUENCES IN SCHEMA %s FROM %s", sanitizedSchema, sanitizedRoleName) | ||
| l.Debug("revoking sequence grants", zap.String("query", revokeSequencesQuery)) | ||
| if _, err := c.db.Exec(ctx, revokeSequencesQuery); err != nil { | ||
| l.Warn("error revoking sequence grants", zap.String("schema", schema), zap.Error(err)) | ||
| revokeError = errors.Join(revokeError, err) | ||
| } | ||
|
|
||
| revokeFunctionsQuery := fmt.Sprintf("REVOKE ALL ON ALL FUNCTIONS IN SCHEMA %s FROM %s", sanitizedSchema, sanitizedRoleName) | ||
| l.Debug("revoking function grants", zap.String("query", revokeFunctionsQuery)) | ||
| if _, err := c.db.Exec(ctx, revokeFunctionsQuery); err != nil { | ||
| l.Warn("error revoking function grants", zap.String("schema", schema), zap.Error(err)) | ||
| revokeError = errors.Join(revokeError, err) | ||
| } | ||
|
|
||
| typesQuery := ` | ||
| SELECT typname | ||
| FROM pg_type t | ||
| JOIN pg_namespace n ON t.typnamespace = n.oid | ||
| WHERE n.nspname = $1 | ||
| AND t.typtype = 'c'` | ||
|
|
||
| typeRows, err := c.db.Query(ctx, typesQuery, schema) | ||
| if err != nil { | ||
| l.Warn("error querying types", zap.String("schema", schema), zap.Error(err)) | ||
| revokeError = errors.Join(revokeError, err) | ||
| } else { | ||
| defer typeRows.Close() | ||
|
|
||
| for typeRows.Next() { | ||
| var typeName string | ||
| if err := typeRows.Scan(&typeName); err != nil { | ||
| l.Warn("error scanning type name", zap.String("schema", schema), zap.Error(err)) | ||
| revokeError = errors.Join(revokeError, err) | ||
| continue | ||
| } | ||
|
|
||
| sanitizedTypeName := pgx.Identifier{schema, typeName}.Sanitize() | ||
| revokeTypeQuery := fmt.Sprintf("REVOKE ALL ON TYPE %s FROM %s", sanitizedTypeName, sanitizedRoleName) | ||
| l.Debug("revoking type grants", zap.String("query", revokeTypeQuery)) | ||
| if _, err := c.db.Exec(ctx, revokeTypeQuery); err != nil { | ||
| l.Warn("error revoking type grants", zap.String("schema", schema), zap.String("type", typeName), zap.Error(err)) | ||
| revokeError = errors.Join(revokeError, err) | ||
| } | ||
| } | ||
| } | ||
|
|
||
| revokeSchemaQuery := fmt.Sprintf("REVOKE ALL ON SCHEMA %s FROM %s", sanitizedSchema, sanitizedRoleName) | ||
| l.Debug("revoking schema grants", zap.String("query", revokeSchemaQuery)) | ||
| if _, err := c.db.Exec(ctx, revokeSchemaQuery); err != nil { | ||
| l.Warn("error revoking schema grants", zap.String("schema", schema), zap.Error(err)) | ||
| revokeError = errors.Join(revokeError, err) | ||
| } | ||
| } | ||
|
|
||
| revokeDbQuery := fmt.Sprintf("REVOKE ALL ON DATABASE %s FROM %s", pgx.Identifier{c.DatabaseName()}.Sanitize(), sanitizedRoleName) | ||
| l.Debug("revoking database grants", zap.String("query", revokeDbQuery)) | ||
| if _, err := c.db.Exec(ctx, revokeDbQuery); err != nil { | ||
| l.Warn("error revoking database grants", zap.Error(err)) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We don't want to return these errors? Would it be better to join them and return any errors in this function?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think the play is join them, and return them if the final role deletion fails. Will work on that. |
||
| revokeError = errors.Join(revokeError, err) | ||
| } | ||
|
|
||
| if revokeError != nil { | ||
| return errors.Join(errRevokeGrantsFromRole, revokeError) | ||
| } | ||
|
|
||
| return nil | ||
| } | ||
|
|
||
| // RemoveRoleFromAllRoles removes a role from all other roles. | ||
| func (c *Client) RemoveRoleFromAllRoles(ctx context.Context, roleName string) error { | ||
| l := ctxzap.Extract(ctx) | ||
|
|
||
| sanitizedRoleName := pgx.Identifier{roleName}.Sanitize() | ||
|
|
||
| // Get all roles that have this role as a member | ||
| query := ` | ||
| SELECT r.rolname | ||
| FROM pg_roles r | ||
| JOIN pg_auth_members am ON r.oid = am.roleid | ||
| JOIN pg_roles member ON am.member = member.oid | ||
| WHERE member.rolname = $1` | ||
|
|
||
| rows, err := c.db.Query(ctx, query, roleName) | ||
| if err != nil { | ||
| l.Error("error querying role memberships", zap.Error(err)) | ||
| return err | ||
| } | ||
| defer rows.Close() | ||
|
|
||
| var parentRoles []string | ||
| for rows.Next() { | ||
| var parentRole string | ||
| if err := rows.Scan(&parentRole); err != nil { | ||
| l.Error("error scanning parent role", zap.Error(err)) | ||
| return err | ||
| } | ||
| parentRoles = append(parentRoles, parentRole) | ||
| } | ||
|
|
||
| if err := rows.Err(); err != nil { | ||
| l.Error("error iterating parent roles", zap.Error(err)) | ||
| return err | ||
| } | ||
|
|
||
| var revokeError error | ||
| // Remove the role from each parent role | ||
| for _, parentRole := range parentRoles { | ||
| sanitizedParentRole := pgx.Identifier{parentRole}.Sanitize() | ||
| revokeQuery := fmt.Sprintf("REVOKE %s FROM %s", sanitizedParentRole, sanitizedRoleName) | ||
|
|
||
| l.Debug("removing role from parent role", zap.String("query", revokeQuery)) | ||
| if _, err := c.db.Exec(ctx, revokeQuery); err != nil { | ||
| l.Error("error removing role from parent role", zap.String("parent_role", parentRole), zap.Error(err)) | ||
| revokeError = errors.Join(revokeError, fmt.Errorf("error removing role from %s role: %w", parentRole, err)) | ||
| } | ||
| } | ||
|
|
||
| if revokeError != nil { | ||
| return errors.Join(errRevokeParentRolesFromRole, revokeError) | ||
| } | ||
|
|
||
| return nil | ||
| } | ||
|
|
||
| // SafeDeleteRole safely deletes a role by first revoking grants and removing memberships. | ||
| func (c *Client) SafeDeleteRole(ctx context.Context, roleName string) error { | ||
| l := ctxzap.Extract(ctx) | ||
|
|
||
| if roleName == "" { | ||
| return errors.New("role name cannot be empty") | ||
| } | ||
|
|
||
| ownsObjects, err := c.RoleOwnsObjects(ctx, roleName) | ||
| if err != nil { | ||
| l.Error("error checking if role owns objects", zap.Error(err)) | ||
| return err | ||
| } | ||
|
|
||
| if ownsObjects { | ||
| return fmt.Errorf("cannot delete role '%s': role owns database objects (tables, schemas, functions, etc.). Please transfer ownership or drop objects first", roleName) | ||
| } | ||
|
|
||
| l.Debug("revoking all grants from role", zap.String("role", roleName)) | ||
| grantsRevokeError := c.RevokeAllGrantsFromRole(ctx, roleName) | ||
| if grantsRevokeError != nil { | ||
| l.Error("error revoking grants from role", zap.Error(grantsRevokeError)) | ||
| if !errors.Is(grantsRevokeError, errRevokeGrantsFromRole) { | ||
| return fmt.Errorf("error revoking existing grants from role: %w", grantsRevokeError) | ||
| } | ||
| } | ||
|
|
||
| l.Debug("removing role from all parent roles", zap.String("role", roleName)) | ||
| roleRevokeError := c.RemoveRoleFromAllRoles(ctx, roleName) | ||
| if roleRevokeError != nil { | ||
| l.Error("error removing role from parent roles", zap.Error(roleRevokeError)) | ||
| if !errors.Is(roleRevokeError, errRevokeParentRolesFromRole) { | ||
| return fmt.Errorf("error removing role from parent roles: %w", roleRevokeError) | ||
| } | ||
| } | ||
|
|
||
| sanitizedRoleName := pgx.Identifier{roleName}.Sanitize() | ||
| query := "DROP ROLE " + sanitizedRoleName | ||
| l.Debug("dropping role", zap.String("query", query)) | ||
| _, err = c.db.Exec(ctx, query) | ||
| if err != nil { | ||
| l.Error("error dropping role", zap.Error(err)) | ||
| finalError := errors.Join(err, roleRevokeError, grantsRevokeError) | ||
| return fmt.Errorf("error dropping role(%s): %w", roleName, finalError) | ||
| } | ||
|
|
||
| l.Debug("deleting role", zap.String("query", query)) | ||
| _, err := c.db.Exec(ctx, query) | ||
| return err | ||
| l.Info("successfully deleted role", zap.String("role", roleName)) | ||
| return nil | ||
| } | ||
|
|
||
| func (c *Client) DeleteRole(ctx context.Context, roleName string) error { | ||
| return c.SafeDeleteRole(ctx, roleName) | ||
| } | ||
|
|
||
| func (c *Client) CreateUser(ctx context.Context, login string, password string) (*RoleModel, error) { | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.