@@ -1655,6 +1655,14 @@ func (og *operationGenerator) dropColumn(ctx context.Context, tx pgx.Tx) (*opStm
16551655 return nil , err
16561656 }
16571657
1658+ // Check if the table has any policies
1659+ tableHasPolicies := false
1660+ if tableExists {
1661+ if tableHasPolicies , err = og .tableHasPolicies (ctx , tx , tableName ); err != nil {
1662+ return nil , err
1663+ }
1664+ }
1665+
16581666 columnName , err := og .randColumn (ctx , tx , * tableName , og .pctExisting (true ))
16591667 if err != nil {
16601668 return nil , err
@@ -1703,11 +1711,40 @@ func (og *operationGenerator) dropColumn(ctx context.Context, tx pgx.Tx) (*opStm
17031711 // It is possible the column we are dropping is in the new primary key,
17041712 // so a potential error is an invalid reference in this case.
17051713 {code : pgcode .InvalidColumnReference , condition : og .useDeclarativeSchemaChanger && hasAlterPKSchemaChange },
1714+ // It is possible that we cannot drop column because
1715+ // it is referenced in a policy expression.
1716+ {code : pgcode .InvalidTableDefinition , condition : tableHasPolicies },
17061717 })
17071718 stmt .sql = fmt .Sprintf (`ALTER TABLE %s DROP COLUMN %s` , tableName .String (), columnName .String ())
17081719 return stmt , nil
17091720}
17101721
1722+ // tableHasPolicies checks if a table has any row-level security policies defined
1723+ func (og * operationGenerator ) tableHasPolicies (
1724+ ctx context.Context , tx pgx.Tx , tableName * tree.TableName ,
1725+ ) (bool , error ) {
1726+ // Query to check if a table has any RLS policies
1727+ query := `
1728+ SELECT EXISTS (
1729+ SELECT 1
1730+ FROM pg_policy p
1731+ JOIN pg_class t ON p.polrelid = t.oid
1732+ JOIN pg_namespace n ON t.relnamespace = n.oid
1733+ WHERE t.relname = $1
1734+ AND n.nspname = $2
1735+ LIMIT 1
1736+ )
1737+ `
1738+
1739+ var hasPolicies bool
1740+ err := tx .QueryRow (ctx , query , tableName .Object (), tableName .Schema ()).Scan (& hasPolicies )
1741+ if err != nil {
1742+ return false , err
1743+ }
1744+
1745+ return hasPolicies , nil
1746+ }
1747+
17111748func (og * operationGenerator ) dropColumnDefault (ctx context.Context , tx pgx.Tx ) (* opStmt , error ) {
17121749 tableName , err := og .randTable (ctx , tx , og .pctExisting (true ), "" )
17131750 if err != nil {
@@ -4905,3 +4942,232 @@ func (og *operationGenerator) alterTableRLS(ctx context.Context, tx pgx.Tx) (*op
49054942 opStmt .sql = sqlStatement
49064943 return opStmt , nil
49074944}
4945+
4946+ // generatePolicyExpression creates a random expression suitable for use in policy USING or WITH CHECK clauses
4947+ func (og * operationGenerator ) generatePolicyExpression (
4948+ columns []string , preferredColumn int ,
4949+ ) string {
4950+ if len (columns ) == 0 {
4951+ // Fallback to simple boolean if no columns are available
4952+ if og .randIntn (2 ) == 0 {
4953+ return "true"
4954+ }
4955+ return "false"
4956+ }
4957+
4958+ // Choose a column index (or use the preferred one if valid)
4959+ colIdx := og .randIntn (len (columns ))
4960+ if preferredColumn >= 0 && preferredColumn < len (columns ) {
4961+ colIdx = preferredColumn
4962+ }
4963+
4964+ // Generate a basic expression (IS NULL or IS NOT NULL)
4965+ var expr string
4966+ if og .randIntn (2 ) == 0 {
4967+ // IS NULL check
4968+ expr = fmt .Sprintf ("%s IS NULL" , columns [colIdx ])
4969+ } else {
4970+ // IS NOT NULL check
4971+ expr = fmt .Sprintf ("%s IS NOT NULL" , columns [colIdx ])
4972+ }
4973+
4974+ // Sometimes add complexity to the expression
4975+ if og .randIntn (3 ) == 0 {
4976+ expressionType := og .randIntn (3 )
4977+ switch expressionType {
4978+ case 0 :
4979+ // Add OR TRUE/FALSE
4980+ if og .randIntn (2 ) == 0 {
4981+ expr = fmt .Sprintf ("(%s OR TRUE)" , expr )
4982+ } else {
4983+ expr = fmt .Sprintf ("(%s OR FALSE)" , expr )
4984+ }
4985+ case 1 :
4986+ // Use a different column if available for a compound expression
4987+ if len (columns ) > 1 {
4988+ secondColIdx := (colIdx + 1 ) % len (columns )
4989+ if og .randIntn (2 ) == 0 {
4990+ expr = fmt .Sprintf ("(%s OR %s IS NOT NULL)" , expr , columns [secondColIdx ])
4991+ } else {
4992+ expr = fmt .Sprintf ("(%s AND %s IS NULL)" , expr , columns [secondColIdx ])
4993+ }
4994+ }
4995+ case 2 :
4996+ // Add a comparison with a literal
4997+ if og .randIntn (2 ) == 0 {
4998+ expr = fmt .Sprintf ("((%s) OR (TRUE))" , expr )
4999+ } else {
5000+ expr = fmt .Sprintf ("(%s AND current_user = current_user)" , expr )
5001+ }
5002+ }
5003+ }
5004+
5005+ return expr
5006+ }
5007+
5008+ func (og * operationGenerator ) createPolicy (ctx context.Context , tx pgx.Tx ) (* opStmt , error ) {
5009+ tableName , err := og .randTable (ctx , tx , og .pctExisting (true ), "" )
5010+ if err != nil {
5011+ return nil , err
5012+ }
5013+
5014+ // Check if table exists to include appropriate expected error
5015+ tableExists , err := og .tableExists (ctx , tx , tableName )
5016+ if err != nil {
5017+ return nil , err
5018+ }
5019+
5020+ // Get columns for the table to reference in expressions
5021+ var columns []string
5022+ if tableExists {
5023+ columns , err = og .tableColumnsShuffled (ctx , tx , tableName .String ())
5024+ if err != nil {
5025+ return nil , err
5026+ }
5027+ }
5028+
5029+ // Generate a unique policy name
5030+ policyName := fmt .Sprintf ("policy_%s" , og .newUniqueSeqNumSuffix ())
5031+
5032+ // Determine which policy components to include
5033+ includeUsing := og .randIntn (2 ) == 0 // 50% chance to include a USING expression
5034+ includeWithCheck := og .randIntn (2 ) == 0 // 50% chance to include WITH CHECK
5035+
5036+ // Build the SQL statement
5037+ var sqlStatement strings.Builder
5038+ sqlStatement .WriteString (fmt .Sprintf ("CREATE POLICY %s ON %s" , policyName , tableName ))
5039+
5040+ if includeUsing {
5041+ usingExpr := og .generatePolicyExpression (columns , - 1 ) // -1 means no preferred column
5042+ sqlStatement .WriteString (fmt .Sprintf (" USING (%s)" , usingExpr ))
5043+ }
5044+
5045+ if includeWithCheck {
5046+ // Try to use a different column for WITH CHECK if possible
5047+ preferredColIdx := - 1
5048+ if len (columns ) > 1 && includeUsing {
5049+ preferredColIdx = og .randIntn (len (columns ))
5050+ }
5051+
5052+ withCheckExpr := og .generatePolicyExpression (columns , preferredColIdx )
5053+ sqlStatement .WriteString (fmt .Sprintf (" WITH CHECK (%s)" , withCheckExpr ))
5054+ }
5055+
5056+ // Create the operation statement
5057+ opStmt := makeOpStmt (OpStmtDDL )
5058+ opStmt .sql = sqlStatement .String ()
5059+
5060+ opStmt .expectedExecErrors .addAll (codesWithConditions {
5061+ {code : pgcode .FeatureNotSupported , condition : ! og .useDeclarativeSchemaChanger },
5062+ {code : pgcode .UndefinedTable , condition : ! tableExists },
5063+ })
5064+
5065+ return opStmt , nil
5066+ }
5067+
5068+ // policyInfo pairs a table name with a name of a policy
5069+ type policyInfo struct {
5070+ table tree.TableName
5071+ policyName string
5072+ }
5073+
5074+ // findExistingPolicy returns a policyInfo struct with the qualified table name and policy name.
5075+ // It also returns a boolean indicating whether a policy was found.
5076+ func findExistingPolicy (
5077+ ctx context.Context , tx pgx.Tx , og * operationGenerator ,
5078+ ) (* policyInfo , bool , error ) {
5079+ var policyWithInfo policyInfo
5080+ policyExists := false
5081+
5082+ // Search for tables that have policies
5083+ policyTableQuery := `
5084+ SELECT
5085+ t.relname as table_name,
5086+ n.nspname as schema_name,
5087+ p.polname as policy_name
5088+ FROM
5089+ pg_policy p
5090+ JOIN pg_class t ON p.polrelid = t.oid
5091+ JOIN pg_namespace n ON t.relnamespace = n.oid
5092+ ORDER BY random()
5093+ LIMIT 1
5094+ `
5095+
5096+ rows , err := tx .Query (ctx , policyTableQuery )
5097+ if err != nil {
5098+ return nil , policyExists , err
5099+ }
5100+ defer rows .Close ()
5101+
5102+ // Check if any rows were returned
5103+ for rows .Next () {
5104+ var tableName , schemaName , policyName string
5105+ if err := rows .Scan (& tableName , & schemaName , & policyName ); err != nil {
5106+ return nil , policyExists , err
5107+ }
5108+ policyWithInfo = policyInfo {
5109+ table : tree .MakeTableNameFromPrefix (tree.ObjectNamePrefix {
5110+ SchemaName : tree .Name (schemaName ),
5111+ ExplicitSchema : true ,
5112+ }, tree .Name (tableName )),
5113+ policyName : policyName ,
5114+ }
5115+ policyExists = true
5116+ }
5117+
5118+ return & policyWithInfo , policyExists , nil
5119+ }
5120+
5121+ func (og * operationGenerator ) dropPolicy (ctx context.Context , tx pgx.Tx ) (* opStmt , error ) {
5122+ policyWithInfo , policyExists , err := findExistingPolicy (ctx , tx , og )
5123+ if err != nil {
5124+ return nil , err
5125+ }
5126+
5127+ tableExists := true
5128+
5129+ if ! policyExists {
5130+ // Fall back to random table if no tables with policies were found
5131+ randomTable , err := og .randTable (ctx , tx , og .pctExisting (true ), "" )
5132+ if err != nil {
5133+ return nil , err
5134+ }
5135+ policyWithInfo .table = * randomTable
5136+
5137+ // If we didn't get a real policy name, generate a random one
5138+ if policyWithInfo .policyName == "" {
5139+ policyWithInfo .policyName = fmt .Sprintf ("dummy_policy_%s" , og .newUniqueSeqNumSuffix ())
5140+ }
5141+
5142+ // Check if table exists to include appropriate expected error
5143+ tableExists , err = og .tableExists (ctx , tx , randomTable )
5144+ if err != nil {
5145+ return nil , err
5146+ }
5147+ }
5148+
5149+ // Build the SQL statement
5150+ var sqlStatement strings.Builder
5151+ sqlStatement .WriteString ("DROP POLICY " )
5152+
5153+ // Randomly decide whether to include IF EXISTS (60% chance)
5154+ includeIfExists := og .randIntn (100 ) < 60
5155+ if includeIfExists {
5156+ sqlStatement .WriteString ("IF EXISTS " )
5157+ }
5158+
5159+ sqlStatement .WriteString (fmt .Sprintf ("%s ON %s" , policyWithInfo .policyName , & policyWithInfo .table ))
5160+
5161+ // Create the operation statement
5162+ opStmt := makeOpStmt (OpStmtDDL )
5163+ opStmt .sql = sqlStatement .String ()
5164+
5165+ opStmt .expectedExecErrors .addAll (codesWithConditions {
5166+ // The policy might not exist
5167+ {code : pgcode .UndefinedObject , condition : ! policyExists && ! includeIfExists },
5168+ // Table might not exist
5169+ {code : pgcode .UndefinedTable , condition : ! tableExists && ! includeIfExists },
5170+ })
5171+
5172+ return opStmt , nil
5173+ }
0 commit comments