|
| 1 | +// Copyright 2024 The Cockroach Authors. |
| 2 | +// |
| 3 | +// Use of this software is governed by the CockroachDB Software License |
| 4 | +// included in the /LICENSE file. |
| 5 | + |
| 6 | +package operations |
| 7 | + |
| 8 | +import ( |
| 9 | + "context" |
| 10 | + "fmt" |
| 11 | + "strings" |
| 12 | + "time" |
| 13 | + |
| 14 | + "github.com/cockroachdb/cockroach/pkg/cmd/roachtest/cluster" |
| 15 | + "github.com/cockroachdb/cockroach/pkg/cmd/roachtest/operation" |
| 16 | + "github.com/cockroachdb/cockroach/pkg/cmd/roachtest/operations/helpers" |
| 17 | + "github.com/cockroachdb/cockroach/pkg/cmd/roachtest/option" |
| 18 | + "github.com/cockroachdb/cockroach/pkg/cmd/roachtest/registry" |
| 19 | + "github.com/cockroachdb/cockroach/pkg/cmd/roachtest/roachtestflags" |
| 20 | + "github.com/cockroachdb/cockroach/pkg/util/randutil" |
| 21 | +) |
| 22 | + |
| 23 | +type cleanupRLSPolicy struct { |
| 24 | + db, table string |
| 25 | + policies []string |
| 26 | + originalRLSStmt string |
| 27 | + locked bool |
| 28 | + waitDuration time.Duration |
| 29 | +} |
| 30 | + |
| 31 | +func (cl *cleanupRLSPolicy) Cleanup(ctx context.Context, o operation.Operation, c cluster.Cluster) { |
| 32 | + o.Status(fmt.Sprintf("Scheduling cleanup to happen after %s", cl.waitDuration)) |
| 33 | + |
| 34 | + // Start a goroutine to handle the wait and cleanup. Since this runs in the |
| 35 | + // background, we also create a new context so that the background goroutine |
| 36 | + // isn't aborted by the parent context. |
| 37 | + go func() { |
| 38 | + newCtx := context.Background() |
| 39 | + |
| 40 | + if deadline, ok := ctx.Deadline(); ok { |
| 41 | + var cancel context.CancelFunc |
| 42 | + newCtx, cancel = context.WithDeadline(newCtx, deadline.Add(cl.waitDuration)) |
| 43 | + defer cancel() |
| 44 | + } |
| 45 | + ctx = newCtx |
| 46 | + |
| 47 | + // Wait for the specified duration before performing cleanup. |
| 48 | + time.Sleep(cl.waitDuration) |
| 49 | + |
| 50 | + conn := c.Conn(ctx, o.L(), 1, option.VirtualClusterName(roachtestflags.VirtualCluster)) |
| 51 | + defer conn.Close() |
| 52 | + |
| 53 | + // Switch to the database where the table is located |
| 54 | + o.Status(fmt.Sprintf("switching to database %s for cleanup", cl.db)) |
| 55 | + if _, err := conn.ExecContext(ctx, fmt.Sprintf("USE %s", cl.db)); err != nil { |
| 56 | + o.Fatal(err) |
| 57 | + } |
| 58 | + |
| 59 | + if cl.locked { |
| 60 | + helpers.SetSchemaLocked(ctx, o, conn, cl.db, cl.table, false /* lock */) |
| 61 | + defer helpers.SetSchemaLocked(ctx, o, conn, cl.db, cl.table, true /* lock */) |
| 62 | + } |
| 63 | + |
| 64 | + // Drop all policies that were created |
| 65 | + for _, policy := range cl.policies { |
| 66 | + o.Status(fmt.Sprintf("dropping policy %s on table %s.%s", policy, cl.db, cl.table)) |
| 67 | + _, err := conn.ExecContext(ctx, fmt.Sprintf("DROP POLICY %s ON %s.%s", policy, cl.db, cl.table)) |
| 68 | + if err != nil { |
| 69 | + o.Fatal(err) |
| 70 | + } |
| 71 | + } |
| 72 | + |
| 73 | + // Restore original RLS state or disable it if it wasn't enabled before |
| 74 | + if cl.originalRLSStmt != "" { |
| 75 | + // Restore the original RLS state |
| 76 | + o.Status(fmt.Sprintf("restoring original row level security state for %s.%s", cl.db, cl.table)) |
| 77 | + if _, err := conn.ExecContext(ctx, cl.originalRLSStmt); err != nil { |
| 78 | + o.Fatal(err) |
| 79 | + } |
| 80 | + } else { |
| 81 | + // If the table didn't have RLS before, disable it |
| 82 | + o.Status(fmt.Sprintf("disabling row level security for %s.%s", cl.db, cl.table)) |
| 83 | + _, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %s.%s DISABLE ROW LEVEL SECURITY, NO FORCE ROW LEVEL SECURITY", cl.db, cl.table)) |
| 84 | + if err != nil { |
| 85 | + o.Fatal(err) |
| 86 | + } |
| 87 | + } |
| 88 | + }() |
| 89 | +} |
| 90 | + |
| 91 | +func runAddRLSPolicy( |
| 92 | + ctx context.Context, o operation.Operation, c cluster.Cluster, |
| 93 | +) registry.OperationCleanup { |
| 94 | + conn := c.Conn(ctx, o.L(), 1, option.VirtualClusterName(roachtestflags.VirtualCluster)) |
| 95 | + defer func() { _ = conn.Close() }() |
| 96 | + |
| 97 | + rng, _ := randutil.NewPseudoRand() |
| 98 | + |
| 99 | + // Pick a random table |
| 100 | + dbName := helpers.PickRandomDB(ctx, o, conn, helpers.SystemDBs) |
| 101 | + tableName := helpers.PickRandomTable(ctx, o, conn, dbName) |
| 102 | + |
| 103 | + // Check if the table already has RLS enabled and store the original statement if needed |
| 104 | + var tblName, createStmt string |
| 105 | + err := conn.QueryRowContext(ctx, fmt.Sprintf("SHOW CREATE TABLE %s.%s", dbName, tableName)).Scan(&tblName, &createStmt) |
| 106 | + if err != nil { |
| 107 | + o.Fatal(err) |
| 108 | + } |
| 109 | + |
| 110 | + // Look for any RLS statements in the CREATE TABLE |
| 111 | + originalRLSStmt := "" |
| 112 | + |
| 113 | + // Check if RLS is enabled and capture the original statement |
| 114 | + if strings.Contains(createStmt, "ROW LEVEL SECURITY") { |
| 115 | + // Extract the ALTER TABLE statement for RLS |
| 116 | + lines := strings.Split(createStmt, "\n") |
| 117 | + for _, line := range lines { |
| 118 | + if strings.Contains(line, "ROW LEVEL SECURITY") { |
| 119 | + addMissingSemicolon := "" |
| 120 | + if !strings.Contains(line, ";") { |
| 121 | + addMissingSemicolon = ";" |
| 122 | + } |
| 123 | + // Store the line as the original RLS statement and semicolon if missing |
| 124 | + originalRLSStmt = strings.TrimSpace(line + addMissingSemicolon) |
| 125 | + break |
| 126 | + } |
| 127 | + } |
| 128 | + } |
| 129 | + |
| 130 | + // If the table's schema is locked, then unlock the table and make sure it will |
| 131 | + // be re-locked during cleanup. |
| 132 | + locked := helpers.IsSchemaLocked(o, conn, dbName, tableName) |
| 133 | + if locked { |
| 134 | + helpers.SetSchemaLocked(ctx, o, conn, dbName, tableName, false /* lock */) |
| 135 | + defer helpers.SetSchemaLocked(ctx, o, conn, dbName, tableName, true /* lock */) |
| 136 | + } |
| 137 | + |
| 138 | + // Enable RLS on the table with random FORCE option |
| 139 | + shouldForceRLS := rng.Intn(2) == 0 // 50% chance of using FORCE |
| 140 | + forceClause := "" |
| 141 | + if shouldForceRLS { |
| 142 | + forceClause = ", FORCE ROW LEVEL SECURITY" |
| 143 | + } |
| 144 | + |
| 145 | + o.Status(fmt.Sprintf("enabling row level security on table %s.%s%s", dbName, tableName, forceClause)) |
| 146 | + _, err = conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %s.%s ENABLE ROW LEVEL SECURITY%s", dbName, tableName, forceClause)) |
| 147 | + if err != nil { |
| 148 | + o.Fatal(err) |
| 149 | + } |
| 150 | + |
| 151 | + // Create between 0-5 policies |
| 152 | + numPolicies := rng.Intn(6) |
| 153 | + policies := make([]string, 0, numPolicies) |
| 154 | + |
| 155 | + operations := []string{"ALL", "SELECT", "INSERT", "UPDATE", "DELETE"} |
| 156 | + users := []string{"public", "current_user", "session_user"} |
| 157 | + |
| 158 | + for i := 0; i < numPolicies; i++ { |
| 159 | + // Pick a random operation |
| 160 | + operation := operations[rng.Intn(len(operations))] |
| 161 | + |
| 162 | + // Pick a random user |
| 163 | + user := users[rng.Intn(len(users))] |
| 164 | + |
| 165 | + // Create unique policy name |
| 166 | + policyName := fmt.Sprintf("rls_policy_%s_%d", operation, rng.Uint32()) |
| 167 | + policies = append(policies, policyName) |
| 168 | + |
| 169 | + o.Status(fmt.Sprintf("creating policy %s on table %s.%s for %s to %s", |
| 170 | + policyName, dbName, tableName, user, operation)) |
| 171 | + |
| 172 | + withCheck := "" |
| 173 | + using := "" |
| 174 | + |
| 175 | + // WITH CHECK does is not supported for INSERT and DELETE |
| 176 | + if operation != "SELECT" && operation != "DELETE" { |
| 177 | + // Randomly choose between true or false |
| 178 | + checkExpr := "true" |
| 179 | + if rng.Intn(2) == 0 { |
| 180 | + checkExpr = "false" |
| 181 | + } |
| 182 | + withCheck = fmt.Sprintf("WITH CHECK (%s)", checkExpr) |
| 183 | + } |
| 184 | + |
| 185 | + // USING is not supported for INSERT |
| 186 | + if operation != "INSERT" { |
| 187 | + // Randomly choose between true or false |
| 188 | + usingExpr := "true" |
| 189 | + if rng.Intn(2) == 0 { |
| 190 | + usingExpr = "false" |
| 191 | + } |
| 192 | + using = fmt.Sprintf("USING (%s)", usingExpr) |
| 193 | + } |
| 194 | + |
| 195 | + _, err = conn.ExecContext(ctx, fmt.Sprintf(` |
| 196 | + CREATE POLICY %s ON %s.%s |
| 197 | + FOR %s |
| 198 | + TO %s |
| 199 | + %s |
| 200 | + %s |
| 201 | + `, policyName, dbName, tableName, operation, user, using, withCheck)) |
| 202 | + if err != nil { |
| 203 | + o.Fatal(err) |
| 204 | + } |
| 205 | + } |
| 206 | + |
| 207 | + o.Status(fmt.Sprintf("created %d RLS policies on table %s.%s", numPolicies, dbName, tableName)) |
| 208 | + |
| 209 | + // Return the cleanup struct with wait duration. |
| 210 | + waitTime := time.Hour |
| 211 | + return &cleanupRLSPolicy{ |
| 212 | + db: dbName, |
| 213 | + table: tableName, |
| 214 | + policies: policies, |
| 215 | + originalRLSStmt: originalRLSStmt, |
| 216 | + locked: locked, |
| 217 | + waitDuration: waitTime, |
| 218 | + } |
| 219 | +} |
| 220 | + |
| 221 | +func registerAddRLSPolicy(r registry.Registry) { |
| 222 | + r.AddOperation(registry.OperationSpec{ |
| 223 | + Name: "add-rls-policy", |
| 224 | + Owner: registry.OwnerSQLFoundations, |
| 225 | + Timeout: 30 * time.Minute, |
| 226 | + CompatibleClouds: registry.AllClouds, |
| 227 | + CanRunConcurrently: registry.OperationCanRunConcurrently, |
| 228 | + Dependencies: []registry.OperationDependency{registry.OperationRequiresPopulatedDatabase}, |
| 229 | + Run: runAddRLSPolicy, |
| 230 | + }) |
| 231 | +} |
0 commit comments