Skip to content

Commit 5ec8362

Browse files
committed
Check and remove grants and role memberships before removing a role
1 parent 548a3a8 commit 5ec8362

File tree

1 file changed

+205
-4
lines changed

1 file changed

+205
-4
lines changed

pkg/postgres/roles.go

Lines changed: 205 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -157,19 +157,220 @@ func (c *Client) CreateRole(ctx context.Context, roleName string) error {
157157
return err
158158
}
159159

160-
func (c *Client) DeleteRole(ctx context.Context, roleName string) error {
160+
// RoleOwnsObjects checks if a role owns any database objects.
161+
func (c *Client) RoleOwnsObjects(ctx context.Context, roleName string) (bool, error) {
162+
l := ctxzap.Extract(ctx)
163+
164+
query := `
165+
SELECT EXISTS(
166+
SELECT 1 FROM (
167+
-- Check for owned schemas
168+
SELECT 1 FROM pg_namespace WHERE nspowner = (SELECT oid FROM pg_roles WHERE rolname = $1)
169+
UNION ALL
170+
-- Check for owned tables
171+
SELECT 1 FROM pg_class WHERE relowner = (SELECT oid FROM pg_roles WHERE rolname = $1)
172+
UNION ALL
173+
-- Check for owned functions
174+
SELECT 1 FROM pg_proc WHERE proowner = (SELECT oid FROM pg_roles WHERE rolname = $1)
175+
UNION ALL
176+
-- Check for owned sequences
177+
SELECT 1 FROM pg_class WHERE relowner = (SELECT oid FROM pg_roles WHERE rolname = $1) AND relkind = 'S'
178+
UNION ALL
179+
-- Check for owned views
180+
SELECT 1 FROM pg_class WHERE relowner = (SELECT oid FROM pg_roles WHERE rolname = $1) AND relkind = 'v'
181+
UNION ALL
182+
-- Check for owned types
183+
SELECT 1 FROM pg_type WHERE typowner = (SELECT oid FROM pg_roles WHERE rolname = $1)
184+
UNION ALL
185+
-- Check for owned databases
186+
SELECT 1 FROM pg_database WHERE datdba = (SELECT oid FROM pg_roles WHERE rolname = $1)
187+
) owned_objects
188+
)`
189+
190+
var ownsObjects bool
191+
err := c.db.QueryRow(ctx, query, roleName).Scan(&ownsObjects)
192+
if err != nil {
193+
l.Error("error checking if role owns objects", zap.Error(err))
194+
return false, err
195+
}
196+
197+
return ownsObjects, nil
198+
}
199+
200+
// RevokeAllGrantsFromRole revokes all grants from a role across all schemas.
201+
func (c *Client) RevokeAllGrantsFromRole(ctx context.Context, roleName string) error {
202+
l := ctxzap.Extract(ctx)
203+
204+
sanitizedRoleName := pgx.Identifier{roleName}.Sanitize()
205+
206+
schemasQuery := `
207+
SELECT nspname
208+
FROM pg_namespace
209+
WHERE nspname NOT LIKE 'pg_%'
210+
AND nspname != 'information_schema'
211+
ORDER BY nspname`
212+
213+
rows, err := c.db.Query(ctx, schemasQuery)
214+
if err != nil {
215+
l.Error("error querying schemas", zap.Error(err))
216+
return err
217+
}
218+
defer rows.Close()
219+
220+
var schemas []string
221+
for rows.Next() {
222+
var schemaName string
223+
if err := rows.Scan(&schemaName); err != nil {
224+
l.Error("error scanning schema name", zap.Error(err))
225+
return err
226+
}
227+
schemas = append(schemas, schemaName)
228+
}
229+
230+
if err := rows.Err(); err != nil {
231+
l.Error("error iterating schemas", zap.Error(err))
232+
return err
233+
}
234+
235+
for _, schema := range schemas {
236+
sanitizedSchema := pgx.Identifier{schema}.Sanitize()
237+
238+
revokeTablesQuery := fmt.Sprintf("REVOKE ALL ON ALL TABLES IN SCHEMA %s FROM %s", sanitizedSchema, sanitizedRoleName)
239+
l.Debug("revoking table grants", zap.String("query", revokeTablesQuery))
240+
if _, err := c.db.Exec(ctx, revokeTablesQuery); err != nil {
241+
l.Warn("error revoking table grants", zap.String("schema", schema), zap.Error(err))
242+
}
243+
244+
revokeSequencesQuery := fmt.Sprintf("REVOKE ALL ON ALL SEQUENCES IN SCHEMA %s FROM %s", sanitizedSchema, sanitizedRoleName)
245+
l.Debug("revoking sequence grants", zap.String("query", revokeSequencesQuery))
246+
if _, err := c.db.Exec(ctx, revokeSequencesQuery); err != nil {
247+
l.Warn("error revoking sequence grants", zap.String("schema", schema), zap.Error(err))
248+
}
249+
250+
revokeFunctionsQuery := fmt.Sprintf("REVOKE ALL ON ALL FUNCTIONS IN SCHEMA %s FROM %s", sanitizedSchema, sanitizedRoleName)
251+
l.Debug("revoking function grants", zap.String("query", revokeFunctionsQuery))
252+
if _, err := c.db.Exec(ctx, revokeFunctionsQuery); err != nil {
253+
l.Warn("error revoking function grants", zap.String("schema", schema), zap.Error(err))
254+
}
255+
256+
revokeTypesQuery := fmt.Sprintf("REVOKE ALL ON ALL TYPES IN SCHEMA %s FROM %s", sanitizedSchema, sanitizedRoleName)
257+
l.Debug("revoking type grants", zap.String("query", revokeTypesQuery))
258+
if _, err := c.db.Exec(ctx, revokeTypesQuery); err != nil {
259+
l.Warn("error revoking type grants", zap.String("schema", schema), zap.Error(err))
260+
}
261+
262+
revokeSchemaQuery := fmt.Sprintf("REVOKE ALL ON SCHEMA %s FROM %s", sanitizedSchema, sanitizedRoleName)
263+
l.Debug("revoking schema grants", zap.String("query", revokeSchemaQuery))
264+
if _, err := c.db.Exec(ctx, revokeSchemaQuery); err != nil {
265+
l.Warn("error revoking schema grants", zap.String("schema", schema), zap.Error(err))
266+
}
267+
}
268+
269+
revokeDbQuery := fmt.Sprintf("REVOKE ALL ON DATABASE %s FROM %s", pgx.Identifier{c.DatabaseName()}.Sanitize(), sanitizedRoleName)
270+
l.Debug("revoking database grants", zap.String("query", revokeDbQuery))
271+
if _, err := c.db.Exec(ctx, revokeDbQuery); err != nil {
272+
l.Warn("error revoking database grants", zap.Error(err))
273+
}
274+
275+
return nil
276+
}
277+
278+
// RemoveRoleFromAllRoles removes a role from all other roles
279+
func (c *Client) RemoveRoleFromAllRoles(ctx context.Context, roleName string) error {
280+
l := ctxzap.Extract(ctx)
281+
282+
sanitizedRoleName := pgx.Identifier{roleName}.Sanitize()
283+
284+
// Get all roles that have this role as a member
285+
query := `
286+
SELECT r.rolname
287+
FROM pg_roles r
288+
JOIN pg_auth_members am ON r.oid = am.roleid
289+
JOIN pg_roles member ON am.member = member.oid
290+
WHERE member.rolname = $1`
291+
292+
rows, err := c.db.Query(ctx, query, roleName)
293+
if err != nil {
294+
l.Error("error querying role memberships", zap.Error(err))
295+
return err
296+
}
297+
defer rows.Close()
298+
299+
var parentRoles []string
300+
for rows.Next() {
301+
var parentRole string
302+
if err := rows.Scan(&parentRole); err != nil {
303+
l.Error("error scanning parent role", zap.Error(err))
304+
return err
305+
}
306+
parentRoles = append(parentRoles, parentRole)
307+
}
308+
309+
if err := rows.Err(); err != nil {
310+
l.Error("error iterating parent roles", zap.Error(err))
311+
return err
312+
}
313+
314+
// Remove the role from each parent role
315+
for _, parentRole := range parentRoles {
316+
sanitizedParentRole := pgx.Identifier{parentRole}.Sanitize()
317+
revokeQuery := fmt.Sprintf("REVOKE %s FROM %s", sanitizedParentRole, sanitizedRoleName)
318+
319+
l.Debug("removing role from parent role", zap.String("query", revokeQuery))
320+
if _, err := c.db.Exec(ctx, revokeQuery); err != nil {
321+
l.Error("error removing role from parent role", zap.String("parent_role", parentRole), zap.Error(err))
322+
return err
323+
}
324+
}
325+
326+
return nil
327+
}
328+
329+
// SafeDeleteRole safely deletes a role by first revoking grants and removing memberships.
330+
func (c *Client) SafeDeleteRole(ctx context.Context, roleName string) error {
161331
l := ctxzap.Extract(ctx)
162332

163333
if roleName == "" {
164334
return errors.New("role name cannot be empty")
165335
}
166336

337+
ownsObjects, err := c.RoleOwnsObjects(ctx, roleName)
338+
if err != nil {
339+
l.Error("error checking if role owns objects", zap.Error(err))
340+
return err
341+
}
342+
343+
if ownsObjects {
344+
return fmt.Errorf("cannot delete role '%s': role owns database objects (tables, schemas, functions, etc.). Please transfer ownership or drop objects first", roleName)
345+
}
346+
347+
l.Debug("revoking all grants from role", zap.String("role", roleName))
348+
if err := c.RevokeAllGrantsFromRole(ctx, roleName); err != nil {
349+
l.Error("error revoking grants from role", zap.Error(err))
350+
return err
351+
}
352+
353+
l.Debug("removing role from all parent roles", zap.String("role", roleName))
354+
if err := c.RemoveRoleFromAllRoles(ctx, roleName); err != nil {
355+
l.Error("error removing role from parent roles", zap.Error(err))
356+
return err
357+
}
358+
167359
sanitizedRoleName := pgx.Identifier{roleName}.Sanitize()
168360
query := "DROP ROLE " + sanitizedRoleName
361+
l.Debug("dropping role", zap.String("query", query))
362+
_, err = c.db.Exec(ctx, query)
363+
if err != nil {
364+
l.Error("error dropping role", zap.Error(err))
365+
return err
366+
}
169367

170-
l.Debug("deleting role", zap.String("query", query))
171-
_, err := c.db.Exec(ctx, query)
172-
return err
368+
l.Info("successfully deleted role", zap.String("role", roleName))
369+
return nil
370+
}
371+
372+
func (c *Client) DeleteRole(ctx context.Context, roleName string) error {
373+
return c.SafeDeleteRole(ctx, roleName)
173374
}
174375

175376
func (c *Client) CreateUser(ctx context.Context, login string, password string) (*RoleModel, error) {

0 commit comments

Comments
 (0)