Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 31 additions & 7 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,12 @@ jobs:
POSTGRES_PASSWORD: secretpassword
env:
BATON_LOG_LEVEL: debug
BATON_DSN: 'postgres://postgres:secretpassword@localhost:5432/postgres'
CONNECTOR_GRANT: 'grant:entitlement:role:3375:member:role:10'
CONNECTOR_ENTITLEMENT: 'entitlement:role:3375:member'
CONNECTOR_PRINCIPAL: 'role:10'
CONNECTOR_PRINCIPAL_TYPE: 'role'
BATON_DSN: "postgres://postgres:secretpassword@localhost:5432/postgres"
CONNECTOR_GRANT: "grant:entitlement:role:3375:member:role:10"
CONNECTOR_ENTITLEMENT: "entitlement:role:3375:member"
CONNECTOR_PRINCIPAL: "role:10"
CONNECTOR_PRINCIPAL_TYPE: "role"
CONNECTOR_NEW_USER: "testuser"
steps:
- name: Install Go
uses: actions/setup-go@v5
Expand All @@ -63,7 +64,7 @@ jobs:
run: sudo apt install postgresql-client
# - name: Import sql into postgres
# env:
# PGPASSWORD: secretpassword
# PGPASSWORD: secretpassword
# run: psql -h localhost --user postgres -f test/ci.sql
- name: Install baton
run: ./scripts/get-baton.sh && mv baton /usr/local/bin
Expand Down Expand Up @@ -91,7 +92,30 @@ jobs:
run: ./baton-postgresql && baton grants --entitlement "${{ env.CONNECTOR_ENTITLEMENT }}" --output-format=json | jq --exit-status ".grants[].principal.id.resource == \"${{ env.CONNECTOR_PRINCIPAL }}\""

- name: Create user
run: ./baton-postgresql --create-account-login 'testuser'
run: ./baton-postgresql --create-account-login "${{ env.CONNECTOR_NEW_USER }}"

- name: Check user was created
run: ./baton-postgresql && baton resources -o json | jq -e --arg login "${{ env.CONNECTOR_NEW_USER }}" 'any(.resources[].resource.annotations[]?;.["@type"]=="type.googleapis.com/c1.connector.v2.UserTrait" and .login==$login)'

- name: Fetch user id
shell: bash
run: |
set -eub pipefail
NEW_USER_ID=$(baton resources -t role -o json | jq -r --arg login "${{ env.CONNECTOR_NEW_USER }}" '.resources[].resource | select(any(.annotations[]?; .["@type"]=="type.googleapis.com/c1.connector.v2.UserTrait" and .login==$login)) | .id.resource')
echo "NEW_USER_ID=$NEW_USER_ID" >> "$GITHUB_ENV"

- name: Grant role to user
run: ./baton-postgresql --grant-entitlement "${{ env.CONNECTOR_ENTITLEMENT }}" --grant-principal "${{ env.NEW_USER_ID }}" --grant-principal-type "${{ env.CONNECTOR_PRINCIPAL_TYPE }}"

- name: Check role was granted
run: ./baton-postgresql && baton grants --entitlement "${{ env.CONNECTOR_ENTITLEMENT }}" -o json | jq -e --arg login "${{ env.CONNECTOR_NEW_USER }}" 'any(.grants[]?; any(.principal.annotations[]?; .["@type"]=="type.googleapis.com/c1.connector.v2.UserTrait" and .login==$login) or any(.grant.principal.annotations[]?; .["@type"]=="type.googleapis.com/c1.connector.v2.UserTrait" and .login==$login))'

- name: Delete user
run: ./baton-postgresql --delete-resource "${{ env.NEW_USER_ID }}" --delete-resource-type "${{ env.CONNECTOR_PRINCIPAL_TYPE }}"

- name: Check user was deleted
run: ./baton-postgresql && baton resources -o json | jq -e --arg login "${{ env.CONNECTOR_NEW_USER }}" 'any(.resources[].resource.annotations[]?;.["@type"]=="type.googleapis.com/c1.connector.v2.UserTrait" and .login==$login) | not'

# TODO: get correct role id using baton CLI
# - name: Rotate credentials for user
# run: ./baton-postgresql --rotate-credentials 'role:16384' --rotate-credentials-type 'role'
260 changes: 256 additions & 4 deletions pkg/postgres/roles.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ import (
"go.uber.org/zap"
)

var errRevokeGrantsFromRole = errors.New("error revoking grants from role")
var errRevokeParentRolesFromRole = errors.New("error revoking parent roles from role")

type RoleModel struct {
ID int64 `db:"oid"`
Name string `db:"rolname"`
Expand Down Expand Up @@ -157,19 +160,268 @@ func (c *Client) CreateRole(ctx context.Context, roleName string) error {
return err
}

func (c *Client) DeleteRole(ctx context.Context, roleName string) error {
// RoleOwnsObjects checks if a role owns any database objects.
func (c *Client) RoleOwnsObjects(ctx context.Context, roleName string) (bool, error) {
l := ctxzap.Extract(ctx)

query := `
SELECT EXISTS(
SELECT 1 FROM (
-- Check for owned schemas
SELECT 1 FROM pg_namespace WHERE nspowner = (SELECT oid FROM pg_roles WHERE rolname = $1)
UNION ALL
-- Check for owned tables
SELECT 1 FROM pg_class WHERE relowner = (SELECT oid FROM pg_roles WHERE rolname = $1)
UNION ALL
-- Check for owned functions
SELECT 1 FROM pg_proc WHERE proowner = (SELECT oid FROM pg_roles WHERE rolname = $1)
UNION ALL
-- Check for owned sequences
SELECT 1 FROM pg_class WHERE relowner = (SELECT oid FROM pg_roles WHERE rolname = $1) AND relkind = 'S'
UNION ALL
-- Check for owned views
SELECT 1 FROM pg_class WHERE relowner = (SELECT oid FROM pg_roles WHERE rolname = $1) AND relkind = 'v'
UNION ALL
-- Check for owned types
SELECT 1 FROM pg_type WHERE typowner = (SELECT oid FROM pg_roles WHERE rolname = $1)
UNION ALL
-- Check for owned databases
SELECT 1 FROM pg_database WHERE datdba = (SELECT oid FROM pg_roles WHERE rolname = $1)
) owned_objects
)`

var ownsObjects bool
err := c.db.QueryRow(ctx, query, roleName).Scan(&ownsObjects)
if err != nil {
l.Error("error checking if role owns objects", zap.Error(err))
return false, err
}

return ownsObjects, nil
}

// RevokeAllGrantsFromRole revokes all grants from a role across all schemas.
func (c *Client) RevokeAllGrantsFromRole(ctx context.Context, roleName string) error {
l := ctxzap.Extract(ctx)

sanitizedRoleName := pgx.Identifier{roleName}.Sanitize()

schemasQuery := `
SELECT nspname
FROM pg_namespace
WHERE nspname NOT LIKE 'pg_%'
AND nspname != 'information_schema'
ORDER BY nspname`

rows, err := c.db.Query(ctx, schemasQuery)
if err != nil {
l.Error("error querying schemas", zap.Error(err))
return err
}
defer rows.Close()

var schemas []string
for rows.Next() {
var schemaName string
if err := rows.Scan(&schemaName); err != nil {
l.Error("error scanning schema name", zap.Error(err))
return err
}
schemas = append(schemas, schemaName)
}

if err := rows.Err(); err != nil {
l.Error("error iterating schemas", zap.Error(err))
return err
}

var revokeError error
for _, schema := range schemas {
sanitizedSchema := pgx.Identifier{schema}.Sanitize()

revokeTablesQuery := fmt.Sprintf("REVOKE ALL ON ALL TABLES IN SCHEMA %s FROM %s", sanitizedSchema, sanitizedRoleName)
l.Debug("revoking table grants", zap.String("query", revokeTablesQuery))
if _, err := c.db.Exec(ctx, revokeTablesQuery); err != nil {
l.Warn("error revoking table grants", zap.String("schema", schema), zap.Error(err))
revokeError = errors.Join(revokeError, err)
}

revokeSequencesQuery := fmt.Sprintf("REVOKE ALL ON ALL SEQUENCES IN SCHEMA %s FROM %s", sanitizedSchema, sanitizedRoleName)
l.Debug("revoking sequence grants", zap.String("query", revokeSequencesQuery))
if _, err := c.db.Exec(ctx, revokeSequencesQuery); err != nil {
l.Warn("error revoking sequence grants", zap.String("schema", schema), zap.Error(err))
revokeError = errors.Join(revokeError, err)
}

revokeFunctionsQuery := fmt.Sprintf("REVOKE ALL ON ALL FUNCTIONS IN SCHEMA %s FROM %s", sanitizedSchema, sanitizedRoleName)
l.Debug("revoking function grants", zap.String("query", revokeFunctionsQuery))
if _, err := c.db.Exec(ctx, revokeFunctionsQuery); err != nil {
l.Warn("error revoking function grants", zap.String("schema", schema), zap.Error(err))
revokeError = errors.Join(revokeError, err)
}

typesQuery := `
SELECT typname
FROM pg_type t
JOIN pg_namespace n ON t.typnamespace = n.oid
WHERE n.nspname = $1
AND t.typtype = 'c'`

typeRows, err := c.db.Query(ctx, typesQuery, schema)
if err != nil {
l.Warn("error querying types", zap.String("schema", schema), zap.Error(err))
revokeError = errors.Join(revokeError, err)
} else {
defer typeRows.Close()

for typeRows.Next() {
var typeName string
if err := typeRows.Scan(&typeName); err != nil {
l.Warn("error scanning type name", zap.String("schema", schema), zap.Error(err))
revokeError = errors.Join(revokeError, err)
continue
}

sanitizedTypeName := pgx.Identifier{schema, typeName}.Sanitize()
revokeTypeQuery := fmt.Sprintf("REVOKE ALL ON TYPE %s FROM %s", sanitizedTypeName, sanitizedRoleName)
l.Debug("revoking type grants", zap.String("query", revokeTypeQuery))
if _, err := c.db.Exec(ctx, revokeTypeQuery); err != nil {
l.Warn("error revoking type grants", zap.String("schema", schema), zap.String("type", typeName), zap.Error(err))
revokeError = errors.Join(revokeError, err)
}
}
}

revokeSchemaQuery := fmt.Sprintf("REVOKE ALL ON SCHEMA %s FROM %s", sanitizedSchema, sanitizedRoleName)
l.Debug("revoking schema grants", zap.String("query", revokeSchemaQuery))
if _, err := c.db.Exec(ctx, revokeSchemaQuery); err != nil {
l.Warn("error revoking schema grants", zap.String("schema", schema), zap.Error(err))
revokeError = errors.Join(revokeError, err)
}
}

revokeDbQuery := fmt.Sprintf("REVOKE ALL ON DATABASE %s FROM %s", pgx.Identifier{c.DatabaseName()}.Sanitize(), sanitizedRoleName)
l.Debug("revoking database grants", zap.String("query", revokeDbQuery))
if _, err := c.db.Exec(ctx, revokeDbQuery); err != nil {
l.Warn("error revoking database grants", zap.Error(err))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't want to return these errors? Would it be better to join them and return any errors in this function?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the play is join them, and return them if the final role deletion fails. Will work on that.

revokeError = errors.Join(revokeError, err)
}

if revokeError != nil {
return errors.Join(errRevokeGrantsFromRole, revokeError)
}

return nil
}

// RemoveRoleFromAllRoles removes a role from all other roles.
func (c *Client) RemoveRoleFromAllRoles(ctx context.Context, roleName string) error {
l := ctxzap.Extract(ctx)

sanitizedRoleName := pgx.Identifier{roleName}.Sanitize()

// Get all roles that have this role as a member
query := `
SELECT r.rolname
FROM pg_roles r
JOIN pg_auth_members am ON r.oid = am.roleid
JOIN pg_roles member ON am.member = member.oid
WHERE member.rolname = $1`

rows, err := c.db.Query(ctx, query, roleName)
if err != nil {
l.Error("error querying role memberships", zap.Error(err))
return err
}
defer rows.Close()

var parentRoles []string
for rows.Next() {
var parentRole string
if err := rows.Scan(&parentRole); err != nil {
l.Error("error scanning parent role", zap.Error(err))
return err
}
parentRoles = append(parentRoles, parentRole)
}

if err := rows.Err(); err != nil {
l.Error("error iterating parent roles", zap.Error(err))
return err
}

var revokeError error
// Remove the role from each parent role
for _, parentRole := range parentRoles {
sanitizedParentRole := pgx.Identifier{parentRole}.Sanitize()
revokeQuery := fmt.Sprintf("REVOKE %s FROM %s", sanitizedParentRole, sanitizedRoleName)

l.Debug("removing role from parent role", zap.String("query", revokeQuery))
if _, err := c.db.Exec(ctx, revokeQuery); err != nil {
l.Error("error removing role from parent role", zap.String("parent_role", parentRole), zap.Error(err))
revokeError = errors.Join(revokeError, fmt.Errorf("error removing role from %s role: %w", parentRole, err))
}
}

if revokeError != nil {
return errors.Join(errRevokeParentRolesFromRole, revokeError)
}

return nil
}

// SafeDeleteRole safely deletes a role by first revoking grants and removing memberships.
func (c *Client) SafeDeleteRole(ctx context.Context, roleName string) error {
l := ctxzap.Extract(ctx)

if roleName == "" {
return errors.New("role name cannot be empty")
}

ownsObjects, err := c.RoleOwnsObjects(ctx, roleName)
if err != nil {
l.Error("error checking if role owns objects", zap.Error(err))
return err
}

if ownsObjects {
return fmt.Errorf("cannot delete role '%s': role owns database objects (tables, schemas, functions, etc.). Please transfer ownership or drop objects first", roleName)
}

l.Debug("revoking all grants from role", zap.String("role", roleName))
grantsRevokeError := c.RevokeAllGrantsFromRole(ctx, roleName)
if grantsRevokeError != nil {
l.Error("error revoking grants from role", zap.Error(grantsRevokeError))
if !errors.Is(grantsRevokeError, errRevokeGrantsFromRole) {
return fmt.Errorf("error revoking existing grants from role: %w", grantsRevokeError)
}
}

l.Debug("removing role from all parent roles", zap.String("role", roleName))
roleRevokeError := c.RemoveRoleFromAllRoles(ctx, roleName)
if roleRevokeError != nil {
l.Error("error removing role from parent roles", zap.Error(roleRevokeError))
if !errors.Is(roleRevokeError, errRevokeParentRolesFromRole) {
return fmt.Errorf("error removing role from parent roles: %w", roleRevokeError)
}
}

sanitizedRoleName := pgx.Identifier{roleName}.Sanitize()
query := "DROP ROLE " + sanitizedRoleName
l.Debug("dropping role", zap.String("query", query))
_, err = c.db.Exec(ctx, query)
if err != nil {
l.Error("error dropping role", zap.Error(err))
finalError := errors.Join(err, roleRevokeError, grantsRevokeError)
return fmt.Errorf("error dropping role(%s): %w", roleName, finalError)
}

l.Debug("deleting role", zap.String("query", query))
_, err := c.db.Exec(ctx, query)
return err
l.Info("successfully deleted role", zap.String("role", roleName))
return nil
}

func (c *Client) DeleteRole(ctx context.Context, roleName string) error {
return c.SafeDeleteRole(ctx, roleName)
}

func (c *Client) CreateUser(ctx context.Context, login string, password string) (*RoleModel, error) {
Expand Down
Loading