Skip to content
209 changes: 205 additions & 4 deletions pkg/postgres/roles.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,19 +157,220 @@
return err
}

func (c *Client) DeleteRole(ctx context.Context, roleName string) error {
// RoleOwnsObjects checks if a role owns any database objects.
func (c *Client) RoleOwnsObjects(ctx context.Context, roleName string) (bool, error) {
l := ctxzap.Extract(ctx)

query := `
SELECT EXISTS(
SELECT 1 FROM (
-- Check for owned schemas
SELECT 1 FROM pg_namespace WHERE nspowner = (SELECT oid FROM pg_roles WHERE rolname = $1)
UNION ALL
-- Check for owned tables
SELECT 1 FROM pg_class WHERE relowner = (SELECT oid FROM pg_roles WHERE rolname = $1)
UNION ALL
-- Check for owned functions
SELECT 1 FROM pg_proc WHERE proowner = (SELECT oid FROM pg_roles WHERE rolname = $1)
UNION ALL
-- Check for owned sequences
SELECT 1 FROM pg_class WHERE relowner = (SELECT oid FROM pg_roles WHERE rolname = $1) AND relkind = 'S'
UNION ALL
-- Check for owned views
SELECT 1 FROM pg_class WHERE relowner = (SELECT oid FROM pg_roles WHERE rolname = $1) AND relkind = 'v'
UNION ALL
-- Check for owned types
SELECT 1 FROM pg_type WHERE typowner = (SELECT oid FROM pg_roles WHERE rolname = $1)
UNION ALL
-- Check for owned databases
SELECT 1 FROM pg_database WHERE datdba = (SELECT oid FROM pg_roles WHERE rolname = $1)
) owned_objects
)`

var ownsObjects bool
err := c.db.QueryRow(ctx, query, roleName).Scan(&ownsObjects)
if err != nil {
l.Error("error checking if role owns objects", zap.Error(err))
return false, err
}

return ownsObjects, nil
}

// RevokeAllGrantsFromRole revokes all grants from a role across all schemas.
func (c *Client) RevokeAllGrantsFromRole(ctx context.Context, roleName string) error {
l := ctxzap.Extract(ctx)

sanitizedRoleName := pgx.Identifier{roleName}.Sanitize()

schemasQuery := `
SELECT nspname
FROM pg_namespace
WHERE nspname NOT LIKE 'pg_%'
AND nspname != 'information_schema'
ORDER BY nspname`

rows, err := c.db.Query(ctx, schemasQuery)
if err != nil {
l.Error("error querying schemas", zap.Error(err))
return err
}
defer rows.Close()

var schemas []string
for rows.Next() {
var schemaName string
if err := rows.Scan(&schemaName); err != nil {
l.Error("error scanning schema name", zap.Error(err))
return err
}
schemas = append(schemas, schemaName)
}

if err := rows.Err(); err != nil {
l.Error("error iterating schemas", zap.Error(err))
return err
}

for _, schema := range schemas {
sanitizedSchema := pgx.Identifier{schema}.Sanitize()

revokeTablesQuery := fmt.Sprintf("REVOKE ALL ON ALL TABLES IN SCHEMA %s FROM %s", sanitizedSchema, sanitizedRoleName)
l.Debug("revoking table grants", zap.String("query", revokeTablesQuery))
if _, err := c.db.Exec(ctx, revokeTablesQuery); err != nil {
l.Warn("error revoking table grants", zap.String("schema", schema), zap.Error(err))
}

revokeSequencesQuery := fmt.Sprintf("REVOKE ALL ON ALL SEQUENCES IN SCHEMA %s FROM %s", sanitizedSchema, sanitizedRoleName)
l.Debug("revoking sequence grants", zap.String("query", revokeSequencesQuery))
if _, err := c.db.Exec(ctx, revokeSequencesQuery); err != nil {
l.Warn("error revoking sequence grants", zap.String("schema", schema), zap.Error(err))
}

revokeFunctionsQuery := fmt.Sprintf("REVOKE ALL ON ALL FUNCTIONS IN SCHEMA %s FROM %s", sanitizedSchema, sanitizedRoleName)
l.Debug("revoking function grants", zap.String("query", revokeFunctionsQuery))
if _, err := c.db.Exec(ctx, revokeFunctionsQuery); err != nil {
l.Warn("error revoking function grants", zap.String("schema", schema), zap.Error(err))
}

revokeTypesQuery := fmt.Sprintf("REVOKE ALL ON ALL TYPES IN SCHEMA %s FROM %s", sanitizedSchema, sanitizedRoleName)
l.Debug("revoking type grants", zap.String("query", revokeTypesQuery))
if _, err := c.db.Exec(ctx, revokeTypesQuery); err != nil {
l.Warn("error revoking type grants", zap.String("schema", schema), zap.Error(err))
}

revokeSchemaQuery := fmt.Sprintf("REVOKE ALL ON SCHEMA %s FROM %s", sanitizedSchema, sanitizedRoleName)
l.Debug("revoking schema grants", zap.String("query", revokeSchemaQuery))
if _, err := c.db.Exec(ctx, revokeSchemaQuery); err != nil {
l.Warn("error revoking schema grants", zap.String("schema", schema), zap.Error(err))
}
}

revokeDbQuery := fmt.Sprintf("REVOKE ALL ON DATABASE %s FROM %s", pgx.Identifier{c.DatabaseName()}.Sanitize(), sanitizedRoleName)
l.Debug("revoking database grants", zap.String("query", revokeDbQuery))
if _, err := c.db.Exec(ctx, revokeDbQuery); err != nil {
l.Warn("error revoking database grants", zap.Error(err))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't want to return these errors? Would it be better to join them and return any errors in this function?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the play is join them, and return them if the final role deletion fails. Will work on that.

}
Comment on lines 244 to 312
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Errors during grant revocation should not be silenced.

Grant revocation errors are logged as warnings but not returned (lines 241, 247, 253, 259, 265, 272). If revocation fails, SafeDeleteRole proceeds to drop the role, which may then fail with unclear error messages. Additionally, partial revocation leaves the database in an inconsistent state.

Consider either:

  1. Returning the first revocation error immediately to halt the process.
  2. Collecting all errors and returning them as a combined error after attempting all revocations.

Apply this diff for approach 1 (fail fast):

 		revokeTablesQuery := fmt.Sprintf("REVOKE ALL ON ALL TABLES IN SCHEMA %s FROM %s", sanitizedSchema, sanitizedRoleName)
 		l.Debug("revoking table grants", zap.String("query", revokeTablesQuery))
 		if _, err := c.db.Exec(ctx, revokeTablesQuery); err != nil {
-			l.Warn("error revoking table grants", zap.String("schema", schema), zap.Error(err))
+			l.Error("error revoking table grants", zap.String("schema", schema), zap.Error(err))
+			return err
 		}
 
 		revokeSequencesQuery := fmt.Sprintf("REVOKE ALL ON ALL SEQUENCES IN SCHEMA %s FROM %s", sanitizedSchema, sanitizedRoleName)
 		l.Debug("revoking sequence grants", zap.String("query", revokeSequencesQuery))
 		if _, err := c.db.Exec(ctx, revokeSequencesQuery); err != nil {
-			l.Warn("error revoking sequence grants", zap.String("schema", schema), zap.Error(err))
+			l.Error("error revoking sequence grants", zap.String("schema", schema), zap.Error(err))
+			return err
 		}
 
 		revokeFunctionsQuery := fmt.Sprintf("REVOKE ALL ON ALL FUNCTIONS IN SCHEMA %s FROM %s", sanitizedSchema, sanitizedRoleName)
 		l.Debug("revoking function grants", zap.String("query", revokeFunctionsQuery))
 		if _, err := c.db.Exec(ctx, revokeFunctionsQuery); err != nil {
-			l.Warn("error revoking function grants", zap.String("schema", schema), zap.Error(err))
+			l.Error("error revoking function grants", zap.String("schema", schema), zap.Error(err))
+			return err
 		}
 
 		revokeTypesQuery := fmt.Sprintf("REVOKE ALL ON ALL TYPES IN SCHEMA %s FROM %s", sanitizedSchema, sanitizedRoleName)
 		l.Debug("revoking type grants", zap.String("query", revokeTypesQuery))
 		if _, err := c.db.Exec(ctx, revokeTypesQuery); err != nil {
-			l.Warn("error revoking type grants", zap.String("schema", schema), zap.Error(err))
+			l.Error("error revoking type grants", zap.String("schema", schema), zap.Error(err))
+			return err
 		}
 
 		revokeSchemaQuery := fmt.Sprintf("REVOKE ALL ON SCHEMA %s FROM %s", sanitizedSchema, sanitizedRoleName)
 		l.Debug("revoking schema grants", zap.String("query", revokeSchemaQuery))
 		if _, err := c.db.Exec(ctx, revokeSchemaQuery); err != nil {
-			l.Warn("error revoking schema grants", zap.String("schema", schema), zap.Error(err))
+			l.Error("error revoking schema grants", zap.String("schema", schema), zap.Error(err))
+			return err
 		}
 	}
 
 	revokeDbQuery := fmt.Sprintf("REVOKE ALL ON DATABASE %s FROM %s", pgx.Identifier{c.DatabaseName()}.Sanitize(), sanitizedRoleName)
 	l.Debug("revoking database grants", zap.String("query", revokeDbQuery))
 	if _, err := c.db.Exec(ctx, revokeDbQuery); err != nil {
-		l.Warn("error revoking database grants", zap.Error(err))
+		l.Error("error revoking database grants", zap.Error(err))
+		return err
 	}


return nil
}

// RemoveRoleFromAllRoles removes a role from all other roles

Check failure on line 278 in pkg/postgres/roles.go

View workflow job for this annotation

GitHub Actions / go-lint

Comment should end in a period (godot)
func (c *Client) RemoveRoleFromAllRoles(ctx context.Context, roleName string) error {
l := ctxzap.Extract(ctx)

sanitizedRoleName := pgx.Identifier{roleName}.Sanitize()

// Get all roles that have this role as a member
query := `
SELECT r.rolname
FROM pg_roles r
JOIN pg_auth_members am ON r.oid = am.roleid
JOIN pg_roles member ON am.member = member.oid
WHERE member.rolname = $1`

rows, err := c.db.Query(ctx, query, roleName)
if err != nil {
l.Error("error querying role memberships", zap.Error(err))
return err
}
defer rows.Close()

var parentRoles []string
for rows.Next() {
var parentRole string
if err := rows.Scan(&parentRole); err != nil {
l.Error("error scanning parent role", zap.Error(err))
return err
}
parentRoles = append(parentRoles, parentRole)
}

if err := rows.Err(); err != nil {
l.Error("error iterating parent roles", zap.Error(err))
return err
}

// Remove the role from each parent role
for _, parentRole := range parentRoles {
sanitizedParentRole := pgx.Identifier{parentRole}.Sanitize()
revokeQuery := fmt.Sprintf("REVOKE %s FROM %s", sanitizedParentRole, sanitizedRoleName)

l.Debug("removing role from parent role", zap.String("query", revokeQuery))
if _, err := c.db.Exec(ctx, revokeQuery); err != nil {
l.Error("error removing role from parent role", zap.String("parent_role", parentRole), zap.Error(err))
return err
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nitpick: I think this error would get bubbled up to the user. Do we want to use fmt.Errorf here to make it clearer what the failure is?

(The same goes for other places where we just return err.)

}
}

return nil
}

// SafeDeleteRole safely deletes a role by first revoking grants and removing memberships.
func (c *Client) SafeDeleteRole(ctx context.Context, roleName string) error {
l := ctxzap.Extract(ctx)

if roleName == "" {
return errors.New("role name cannot be empty")
}

ownsObjects, err := c.RoleOwnsObjects(ctx, roleName)
if err != nil {
l.Error("error checking if role owns objects", zap.Error(err))
return err
}

if ownsObjects {
return fmt.Errorf("cannot delete role '%s': role owns database objects (tables, schemas, functions, etc.). Please transfer ownership or drop objects first", roleName)
}

l.Debug("revoking all grants from role", zap.String("role", roleName))
if err := c.RevokeAllGrantsFromRole(ctx, roleName); err != nil {
l.Error("error revoking grants from role", zap.Error(err))
return err
}

l.Debug("removing role from all parent roles", zap.String("role", roleName))
if err := c.RemoveRoleFromAllRoles(ctx, roleName); err != nil {
l.Error("error removing role from parent roles", zap.Error(err))
return err
}

sanitizedRoleName := pgx.Identifier{roleName}.Sanitize()
query := "DROP ROLE " + sanitizedRoleName
l.Debug("dropping role", zap.String("query", query))
_, err = c.db.Exec(ctx, query)
if err != nil {
l.Error("error dropping role", zap.Error(err))
return err
}

l.Debug("deleting role", zap.String("query", query))
_, err := c.db.Exec(ctx, query)
return err
l.Info("successfully deleted role", zap.String("role", roleName))
return nil
}

func (c *Client) DeleteRole(ctx context.Context, roleName string) error {
return c.SafeDeleteRole(ctx, roleName)
}

func (c *Client) CreateUser(ctx context.Context, login string, password string) (*RoleModel, error) {
Expand Down
Loading