Skip to content

Commit 3bc72aa

Browse files
craig[bot]Dedej-Bergin
andcommitted
Merge #144685
144685: workload: add support for RLS policies in RSW workload r=Dedej-Bergin a=Dedej-Bergin This change adds CREATE POLICY and DROP POLICY operations to the Random Schema Changer workload for testing row-level security (RLS) policy functionality in the declarative schema changer. Informs: #137120 Epic: CRDB-11724 Release note: none Co-authored-by: Bergin Dedej <[email protected]>
2 parents c324bab + befea5e commit 3bc72aa

File tree

3 files changed

+289
-9
lines changed

3 files changed

+289
-9
lines changed

pkg/workload/schemachange/operation_generator.go

Lines changed: 266 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1655,6 +1655,14 @@ func (og *operationGenerator) dropColumn(ctx context.Context, tx pgx.Tx) (*opStm
16551655
return nil, err
16561656
}
16571657

1658+
// Check if the table has any policies
1659+
tableHasPolicies := false
1660+
if tableExists {
1661+
if tableHasPolicies, err = og.tableHasPolicies(ctx, tx, tableName); err != nil {
1662+
return nil, err
1663+
}
1664+
}
1665+
16581666
columnName, err := og.randColumn(ctx, tx, *tableName, og.pctExisting(true))
16591667
if err != nil {
16601668
return nil, err
@@ -1703,11 +1711,40 @@ func (og *operationGenerator) dropColumn(ctx context.Context, tx pgx.Tx) (*opStm
17031711
// It is possible the column we are dropping is in the new primary key,
17041712
// so a potential error is an invalid reference in this case.
17051713
{code: pgcode.InvalidColumnReference, condition: og.useDeclarativeSchemaChanger && hasAlterPKSchemaChange},
1714+
// It is possible that we cannot drop column because
1715+
// it is referenced in a policy expression.
1716+
{code: pgcode.InvalidTableDefinition, condition: tableHasPolicies},
17061717
})
17071718
stmt.sql = fmt.Sprintf(`ALTER TABLE %s DROP COLUMN %s`, tableName.String(), columnName.String())
17081719
return stmt, nil
17091720
}
17101721

1722+
// tableHasPolicies checks if a table has any row-level security policies defined
1723+
func (og *operationGenerator) tableHasPolicies(
1724+
ctx context.Context, tx pgx.Tx, tableName *tree.TableName,
1725+
) (bool, error) {
1726+
// Query to check if a table has any RLS policies
1727+
query := `
1728+
SELECT EXISTS (
1729+
SELECT 1
1730+
FROM pg_policy p
1731+
JOIN pg_class t ON p.polrelid = t.oid
1732+
JOIN pg_namespace n ON t.relnamespace = n.oid
1733+
WHERE t.relname = $1
1734+
AND n.nspname = $2
1735+
LIMIT 1
1736+
)
1737+
`
1738+
1739+
var hasPolicies bool
1740+
err := tx.QueryRow(ctx, query, tableName.Object(), tableName.Schema()).Scan(&hasPolicies)
1741+
if err != nil {
1742+
return false, err
1743+
}
1744+
1745+
return hasPolicies, nil
1746+
}
1747+
17111748
func (og *operationGenerator) dropColumnDefault(ctx context.Context, tx pgx.Tx) (*opStmt, error) {
17121749
tableName, err := og.randTable(ctx, tx, og.pctExisting(true), "")
17131750
if err != nil {
@@ -4905,3 +4942,232 @@ func (og *operationGenerator) alterTableRLS(ctx context.Context, tx pgx.Tx) (*op
49054942
opStmt.sql = sqlStatement
49064943
return opStmt, nil
49074944
}
4945+
4946+
// generatePolicyExpression creates a random expression suitable for use in policy USING or WITH CHECK clauses
4947+
func (og *operationGenerator) generatePolicyExpression(
4948+
columns []string, preferredColumn int,
4949+
) string {
4950+
if len(columns) == 0 {
4951+
// Fallback to simple boolean if no columns are available
4952+
if og.randIntn(2) == 0 {
4953+
return "true"
4954+
}
4955+
return "false"
4956+
}
4957+
4958+
// Choose a column index (or use the preferred one if valid)
4959+
colIdx := og.randIntn(len(columns))
4960+
if preferredColumn >= 0 && preferredColumn < len(columns) {
4961+
colIdx = preferredColumn
4962+
}
4963+
4964+
// Generate a basic expression (IS NULL or IS NOT NULL)
4965+
var expr string
4966+
if og.randIntn(2) == 0 {
4967+
// IS NULL check
4968+
expr = fmt.Sprintf("%s IS NULL", columns[colIdx])
4969+
} else {
4970+
// IS NOT NULL check
4971+
expr = fmt.Sprintf("%s IS NOT NULL", columns[colIdx])
4972+
}
4973+
4974+
// Sometimes add complexity to the expression
4975+
if og.randIntn(3) == 0 {
4976+
expressionType := og.randIntn(3)
4977+
switch expressionType {
4978+
case 0:
4979+
// Add OR TRUE/FALSE
4980+
if og.randIntn(2) == 0 {
4981+
expr = fmt.Sprintf("(%s OR TRUE)", expr)
4982+
} else {
4983+
expr = fmt.Sprintf("(%s OR FALSE)", expr)
4984+
}
4985+
case 1:
4986+
// Use a different column if available for a compound expression
4987+
if len(columns) > 1 {
4988+
secondColIdx := (colIdx + 1) % len(columns)
4989+
if og.randIntn(2) == 0 {
4990+
expr = fmt.Sprintf("(%s OR %s IS NOT NULL)", expr, columns[secondColIdx])
4991+
} else {
4992+
expr = fmt.Sprintf("(%s AND %s IS NULL)", expr, columns[secondColIdx])
4993+
}
4994+
}
4995+
case 2:
4996+
// Add a comparison with a literal
4997+
if og.randIntn(2) == 0 {
4998+
expr = fmt.Sprintf("((%s) OR (TRUE))", expr)
4999+
} else {
5000+
expr = fmt.Sprintf("(%s AND current_user = current_user)", expr)
5001+
}
5002+
}
5003+
}
5004+
5005+
return expr
5006+
}
5007+
5008+
func (og *operationGenerator) createPolicy(ctx context.Context, tx pgx.Tx) (*opStmt, error) {
5009+
tableName, err := og.randTable(ctx, tx, og.pctExisting(true), "")
5010+
if err != nil {
5011+
return nil, err
5012+
}
5013+
5014+
// Check if table exists to include appropriate expected error
5015+
tableExists, err := og.tableExists(ctx, tx, tableName)
5016+
if err != nil {
5017+
return nil, err
5018+
}
5019+
5020+
// Get columns for the table to reference in expressions
5021+
var columns []string
5022+
if tableExists {
5023+
columns, err = og.tableColumnsShuffled(ctx, tx, tableName.String())
5024+
if err != nil {
5025+
return nil, err
5026+
}
5027+
}
5028+
5029+
// Generate a unique policy name
5030+
policyName := fmt.Sprintf("policy_%s", og.newUniqueSeqNumSuffix())
5031+
5032+
// Determine which policy components to include
5033+
includeUsing := og.randIntn(2) == 0 // 50% chance to include a USING expression
5034+
includeWithCheck := og.randIntn(2) == 0 // 50% chance to include WITH CHECK
5035+
5036+
// Build the SQL statement
5037+
var sqlStatement strings.Builder
5038+
sqlStatement.WriteString(fmt.Sprintf("CREATE POLICY %s ON %s", policyName, tableName))
5039+
5040+
if includeUsing {
5041+
usingExpr := og.generatePolicyExpression(columns, -1) // -1 means no preferred column
5042+
sqlStatement.WriteString(fmt.Sprintf(" USING (%s)", usingExpr))
5043+
}
5044+
5045+
if includeWithCheck {
5046+
// Try to use a different column for WITH CHECK if possible
5047+
preferredColIdx := -1
5048+
if len(columns) > 1 && includeUsing {
5049+
preferredColIdx = og.randIntn(len(columns))
5050+
}
5051+
5052+
withCheckExpr := og.generatePolicyExpression(columns, preferredColIdx)
5053+
sqlStatement.WriteString(fmt.Sprintf(" WITH CHECK (%s)", withCheckExpr))
5054+
}
5055+
5056+
// Create the operation statement
5057+
opStmt := makeOpStmt(OpStmtDDL)
5058+
opStmt.sql = sqlStatement.String()
5059+
5060+
opStmt.expectedExecErrors.addAll(codesWithConditions{
5061+
{code: pgcode.FeatureNotSupported, condition: !og.useDeclarativeSchemaChanger},
5062+
{code: pgcode.UndefinedTable, condition: !tableExists},
5063+
})
5064+
5065+
return opStmt, nil
5066+
}
5067+
5068+
// policyInfo pairs a table name with a name of a policy
5069+
type policyInfo struct {
5070+
table tree.TableName
5071+
policyName string
5072+
}
5073+
5074+
// findExistingPolicy returns a policyInfo struct with the qualified table name and policy name.
5075+
// It also returns a boolean indicating whether a policy was found.
5076+
func findExistingPolicy(
5077+
ctx context.Context, tx pgx.Tx, og *operationGenerator,
5078+
) (*policyInfo, bool, error) {
5079+
var policyWithInfo policyInfo
5080+
policyExists := false
5081+
5082+
// Search for tables that have policies
5083+
policyTableQuery := `
5084+
SELECT
5085+
t.relname as table_name,
5086+
n.nspname as schema_name,
5087+
p.polname as policy_name
5088+
FROM
5089+
pg_policy p
5090+
JOIN pg_class t ON p.polrelid = t.oid
5091+
JOIN pg_namespace n ON t.relnamespace = n.oid
5092+
ORDER BY random()
5093+
LIMIT 1
5094+
`
5095+
5096+
rows, err := tx.Query(ctx, policyTableQuery)
5097+
if err != nil {
5098+
return nil, policyExists, err
5099+
}
5100+
defer rows.Close()
5101+
5102+
// Check if any rows were returned
5103+
for rows.Next() {
5104+
var tableName, schemaName, policyName string
5105+
if err := rows.Scan(&tableName, &schemaName, &policyName); err != nil {
5106+
return nil, policyExists, err
5107+
}
5108+
policyWithInfo = policyInfo{
5109+
table: tree.MakeTableNameFromPrefix(tree.ObjectNamePrefix{
5110+
SchemaName: tree.Name(schemaName),
5111+
ExplicitSchema: true,
5112+
}, tree.Name(tableName)),
5113+
policyName: policyName,
5114+
}
5115+
policyExists = true
5116+
}
5117+
5118+
return &policyWithInfo, policyExists, nil
5119+
}
5120+
5121+
func (og *operationGenerator) dropPolicy(ctx context.Context, tx pgx.Tx) (*opStmt, error) {
5122+
policyWithInfo, policyExists, err := findExistingPolicy(ctx, tx, og)
5123+
if err != nil {
5124+
return nil, err
5125+
}
5126+
5127+
tableExists := true
5128+
5129+
if !policyExists {
5130+
// Fall back to random table if no tables with policies were found
5131+
randomTable, err := og.randTable(ctx, tx, og.pctExisting(true), "")
5132+
if err != nil {
5133+
return nil, err
5134+
}
5135+
policyWithInfo.table = *randomTable
5136+
5137+
// If we didn't get a real policy name, generate a random one
5138+
if policyWithInfo.policyName == "" {
5139+
policyWithInfo.policyName = fmt.Sprintf("dummy_policy_%s", og.newUniqueSeqNumSuffix())
5140+
}
5141+
5142+
// Check if table exists to include appropriate expected error
5143+
tableExists, err = og.tableExists(ctx, tx, randomTable)
5144+
if err != nil {
5145+
return nil, err
5146+
}
5147+
}
5148+
5149+
// Build the SQL statement
5150+
var sqlStatement strings.Builder
5151+
sqlStatement.WriteString("DROP POLICY ")
5152+
5153+
// Randomly decide whether to include IF EXISTS (60% chance)
5154+
includeIfExists := og.randIntn(100) < 60
5155+
if includeIfExists {
5156+
sqlStatement.WriteString("IF EXISTS ")
5157+
}
5158+
5159+
sqlStatement.WriteString(fmt.Sprintf("%s ON %s", policyWithInfo.policyName, &policyWithInfo.table))
5160+
5161+
// Create the operation statement
5162+
opStmt := makeOpStmt(OpStmtDDL)
5163+
opStmt.sql = sqlStatement.String()
5164+
5165+
opStmt.expectedExecErrors.addAll(codesWithConditions{
5166+
// The policy might not exist
5167+
{code: pgcode.UndefinedObject, condition: !policyExists && !includeIfExists},
5168+
// Table might not exist
5169+
{code: pgcode.UndefinedTable, condition: !tableExists && !includeIfExists},
5170+
})
5171+
5172+
return opStmt, nil
5173+
}

pkg/workload/schemachange/optype.go

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ const (
122122
createTableAs // CREATE TABLE <table> AS <def>
123123
createView // CREATE VIEW <view> AS <def>
124124
createFunction // CREATE FUNCTION <function> ...
125+
createPolicy // CREATE POLICY <policy> ON <table> [TO <roles>] [USING (<using_expr>)] [WITH CHECK (<check_expr>)]
125126

126127
// COMMENT ON ...
127128

@@ -135,6 +136,7 @@ const (
135136
dropSequence // DROP SEQUENCE <sequence>
136137
dropTable // DROP TABLE <table>
137138
dropView // DROP VIEW <view>
139+
dropPolicy // DROP POLICY [IF EXISTS] <policy> ON <table>
138140

139141
// Unimplemented operations. TODO(sql-foundations): Audit and/or implement these operations.
140142
// alterDatabaseOwner
@@ -234,6 +236,7 @@ var opFuncs = []func(*operationGenerator, context.Context, pgx.Tx) (*opStmt, err
234236
commentOn: (*operationGenerator).commentOn,
235237
createFunction: (*operationGenerator).createFunction,
236238
createIndex: (*operationGenerator).createIndex,
239+
createPolicy: (*operationGenerator).createPolicy,
237240
createSchema: (*operationGenerator).createSchema,
238241
createSequence: (*operationGenerator).createSequence,
239242
createTable: (*operationGenerator).createTable,
@@ -243,6 +246,7 @@ var opFuncs = []func(*operationGenerator, context.Context, pgx.Tx) (*opStmt, err
243246
createView: (*operationGenerator).createView,
244247
dropFunction: (*operationGenerator).dropFunction,
245248
dropIndex: (*operationGenerator).dropIndex,
249+
dropPolicy: (*operationGenerator).dropPolicy,
246250
dropSchema: (*operationGenerator).dropSchema,
247251
dropSequence: (*operationGenerator).dropSequence,
248252
dropTable: (*operationGenerator).dropTable,
@@ -286,6 +290,7 @@ var opWeights = []int{
286290
commentOn: 1,
287291
createFunction: 1,
288292
createIndex: 1,
293+
createPolicy: 1,
289294
createSchema: 1,
290295
createSequence: 1,
291296
createTable: 10,
@@ -295,6 +300,7 @@ var opWeights = []int{
295300
createView: 1,
296301
dropFunction: 1,
297302
dropIndex: 1,
303+
dropPolicy: 1,
298304
dropSchema: 1,
299305
dropSequence: 1,
300306
dropTable: 1,
@@ -324,12 +330,14 @@ var opDeclarativeVersion = map[opType]clusterversion.Key{
324330
alterTableRLS: clusterversion.V25_2,
325331
alterTypeDropValue: clusterversion.MinSupported,
326332
commentOn: clusterversion.MinSupported,
327-
createIndex: clusterversion.MinSupported,
328333
createFunction: clusterversion.MinSupported,
334+
createIndex: clusterversion.MinSupported,
335+
createPolicy: clusterversion.V25_2,
329336
createSchema: clusterversion.MinSupported,
330337
createSequence: clusterversion.MinSupported,
331-
dropIndex: clusterversion.MinSupported,
332338
dropFunction: clusterversion.MinSupported,
339+
dropIndex: clusterversion.MinSupported,
340+
dropPolicy: clusterversion.V25_2,
333341
dropSchema: clusterversion.MinSupported,
334342
dropSequence: clusterversion.MinSupported,
335343
dropTable: clusterversion.MinSupported,

pkg/workload/schemachange/optype_string.go

Lines changed: 13 additions & 7 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)