@@ -173,6 +173,9 @@ type CompilerStage func(*Compiler) *Error
173173type StageID string
174174
175175// Compiler stage identifiers.
176+ // Please use them when you depend on a compiler stage, like via [ast.Compiler.WithStageAfterID].
177+ // There is no guarantee that they are stable across OPA versions, but using the identifiers
178+ // at least lets you know what your attention is needed when you depend on the stages.
176179const (
177180 StageResolveRefs StageID = "ResolveRefs"
178181 StageInitLocalVarGen StageID = "InitLocalVarGen"
@@ -208,6 +211,9 @@ const (
208211 StageBuildRuleIndices StageID = "BuildRuleIndices"
209212 StageBuildComprehensionIndices StageID = "BuildComprehensionIndices"
210213 StageBuildRequiredCapabilities StageID = "BuildRequiredCapabilities"
214+
215+ // These only exist in the [ast.QueryCompiler]:
216+ StageCheckSafety StageID = "CheckSafety"
211217)
212218
213219// AllStages returns the complete list of compiler stages in execution order.
@@ -368,8 +374,15 @@ type QueryCompiler interface {
368374
369375 // WithStageAfter registers a stage to run during query compilation after
370376 // the named stage.
377+ //
378+ // Caution: Use [ast.QueryCompiler.WithStageAfterID] instead. It provides
379+ // more (Golang) compile-time safety
371380 WithStageAfter (after string , stage QueryCompilerStageDefinition ) QueryCompiler
372381
382+ // WithStageAfterID registers a stage to run during query compilation after
383+ // the named stage.
384+ WithStageAfterID (after StageID , stage QueryCompilerStageDefinition ) QueryCompiler
385+
373386 // RewrittenVars maps generated vars in the compiled query to vars from the
374387 // parsed query. For example, given the query "input := 1" the rewritten
375388 // query would be "__local0__ = 1". The mapping would then be {__local0__: input}.
@@ -394,7 +407,7 @@ type QueryCompilerStageDefinition struct {
394407}
395408
396409type stage struct {
397- name string
410+ name StageID
398411 metricName string
399412 f func ()
400413}
@@ -423,43 +436,43 @@ func NewCompiler() *Compiler {
423436 // Reference resolution should run first as it may be used to lazily
424437 // load additional modules. If any stages run before resolution, they
425438 // need to be re-run after resolution.
426- {"ResolveRefs" , "compile_stage_resolve_refs" , c .resolveAllRefs },
439+ {StageResolveRefs , "compile_stage_resolve_refs" , c .resolveAllRefs },
427440 // The local variable generator must be initialized after references are
428441 // resolved and the dynamic module loader has run but before subsequent
429442 // stages that need to generate variables.
430- {"InitLocalVarGen" , "compile_stage_init_local_var_gen" , c .initLocalVarGen },
431- {"RewriteRuleHeadRefs" , "compile_stage_rewrite_rule_head_refs" , c .rewriteRuleHeadRefs },
432- {"CheckKeywordOverrides" , "compile_stage_check_keyword_overrides" , c .checkKeywordOverrides },
433- {"CheckDuplicateImports" , "compile_stage_check_imports" , c .checkImports },
434- {"RemoveImports" , "compile_stage_remove_imports" , c .removeImports },
435- {"SetModuleTree" , "compile_stage_set_module_tree" , c .setModuleTree },
436- {"SetRuleTree" , "compile_stage_set_rule_tree" , c .setRuleTree }, // depends on RewriteRuleHeadRefs
437- {"RewriteLocalVars" , "compile_stage_rewrite_local_vars" , c .rewriteLocalVars },
438- {"RewriteTemplateStrings" , "compile_stage_rewrite_template_strings" , c .rewriteTemplateStrings },
439- {"CheckVoidCalls" , "compile_stage_check_void_calls" , c .checkVoidCalls },
440- {"RewritePrintCalls" , "compile_stage_rewrite_print_calls" , c .rewritePrintCalls },
441- {"RewriteExprTerms" , "compile_stage_rewrite_expr_terms" , c .rewriteExprTerms },
442- {"ParseMetadataBlocks" , "compile_stage_parse_metadata_blocks" , c .parseMetadataBlocks },
443- {"SetAnnotationSet" , "compile_stage_set_annotationset" , c .setAnnotationSet },
444- {"RewriteRegoMetadataCalls" , "compile_stage_rewrite_rego_metadata_calls" , c .rewriteRegoMetadataCalls },
445- {"SetGraph" , "compile_stage_set_graph" , c .setGraph },
446- {"RewriteComprehensionTerms" , "compile_stage_rewrite_comprehension_terms" , c .rewriteComprehensionTerms },
447- {"RewriteRefsInHead" , "compile_stage_rewrite_refs_in_head" , c .rewriteRefsInHead },
448- {"RewriteWithValues" , "compile_stage_rewrite_with_values" , c .rewriteWithModifiers },
449- {"CheckRuleConflicts" , "compile_stage_check_rule_conflicts" , c .checkRuleConflicts },
450- {"CheckUndefinedFuncs" , "compile_stage_check_undefined_funcs" , c .checkUndefinedFuncs },
451- {"CheckSafetyRuleHeads" , "compile_stage_check_safety_rule_heads" , c .checkSafetyRuleHeads },
452- {"CheckSafetyRuleBodies" , "compile_stage_check_safety_rule_bodies" , c .checkSafetyRuleBodies },
453- {"RewriteEquals" , "compile_stage_rewrite_equals" , c .rewriteEquals },
454- {"RewriteDynamicTerms" , "compile_stage_rewrite_dynamic_terms" , c .rewriteDynamicTerms },
455- {"RewriteTestRulesForTracing" , "compile_stage_rewrite_test_rules_for_tracing" , c .rewriteTestRuleEqualities }, // must run after RewriteDynamicTerms
456- {"CheckRecursion" , "compile_stage_check_recursion" , c .checkRecursion },
457- {"CheckTypes" , "compile_stage_check_types" , c .checkTypes }, // must be run after CheckRecursion
458- {"CheckUnsafeBuiltins" , "compile_state_check_unsafe_builtins" , c .checkUnsafeBuiltins },
459- {"CheckDeprecatedBuiltins" , "compile_state_check_deprecated_builtins" , c .checkDeprecatedBuiltins },
460- {"BuildRuleIndices" , "compile_stage_rebuild_indices" , c .buildRuleIndices },
461- {"BuildComprehensionIndices" , "compile_stage_rebuild_comprehension_indices" , c .buildComprehensionIndices },
462- {"BuildRequiredCapabilities" , "compile_stage_build_required_capabilities" , c .buildRequiredCapabilities },
443+ {StageInitLocalVarGen , "compile_stage_init_local_var_gen" , c .initLocalVarGen },
444+ {StageRewriteRuleHeadRefs , "compile_stage_rewrite_rule_head_refs" , c .rewriteRuleHeadRefs },
445+ {StageCheckKeywordOverrides , "compile_stage_check_keyword_overrides" , c .checkKeywordOverrides },
446+ {StageCheckDuplicateImports , "compile_stage_check_imports" , c .checkImports },
447+ {StageRemoveImports , "compile_stage_remove_imports" , c .removeImports },
448+ {StageSetModuleTree , "compile_stage_set_module_tree" , c .setModuleTree },
449+ {StageSetRuleTree , "compile_stage_set_rule_tree" , c .setRuleTree }, // depends on RewriteRuleHeadRefs
450+ {StageRewriteLocalVars , "compile_stage_rewrite_local_vars" , c .rewriteLocalVars },
451+ {StageRewriteTemplateStrings , "compile_stage_rewrite_template_strings" , c .rewriteTemplateStrings },
452+ {StageCheckVoidCalls , "compile_stage_check_void_calls" , c .checkVoidCalls },
453+ {StageRewritePrintCalls , "compile_stage_rewrite_print_calls" , c .rewritePrintCalls },
454+ {StageRewriteExprTerms , "compile_stage_rewrite_expr_terms" , c .rewriteExprTerms },
455+ {StageParseMetadataBlocks , "compile_stage_parse_metadata_blocks" , c .parseMetadataBlocks },
456+ {StageSetAnnotationSet , "compile_stage_set_annotationset" , c .setAnnotationSet },
457+ {StageRewriteRegoMetadataCalls , "compile_stage_rewrite_rego_metadata_calls" , c .rewriteRegoMetadataCalls },
458+ {StageSetGraph , "compile_stage_set_graph" , c .setGraph },
459+ {StageRewriteComprehensionTerms , "compile_stage_rewrite_comprehension_terms" , c .rewriteComprehensionTerms },
460+ {StageRewriteRefsInHead , "compile_stage_rewrite_refs_in_head" , c .rewriteRefsInHead },
461+ {StageRewriteWithValues , "compile_stage_rewrite_with_values" , c .rewriteWithModifiers },
462+ {StageCheckRuleConflicts , "compile_stage_check_rule_conflicts" , c .checkRuleConflicts },
463+ {StageCheckUndefinedFuncs , "compile_stage_check_undefined_funcs" , c .checkUndefinedFuncs },
464+ {StageCheckSafetyRuleHeads , "compile_stage_check_safety_rule_heads" , c .checkSafetyRuleHeads },
465+ {StageCheckSafetyRuleBodies , "compile_stage_check_safety_rule_bodies" , c .checkSafetyRuleBodies },
466+ {StageRewriteEquals , "compile_stage_rewrite_equals" , c .rewriteEquals },
467+ {StageRewriteDynamicTerms , "compile_stage_rewrite_dynamic_terms" , c .rewriteDynamicTerms },
468+ {StageRewriteTestRulesForTracing , "compile_stage_rewrite_test_rules_for_tracing" , c .rewriteTestRuleEqualities }, // must run after RewriteDynamicTerms
469+ {StageCheckRecursion , "compile_stage_check_recursion" , c .checkRecursion },
470+ {StageCheckTypes , "compile_stage_check_types" , c .checkTypes }, // must be run after CheckRecursion
471+ {StageCheckUnsafeBuiltins , "compile_state_check_unsafe_builtins" , c .checkUnsafeBuiltins },
472+ {StageCheckDeprecatedBuiltins , "compile_state_check_deprecated_builtins" , c .checkDeprecatedBuiltins },
473+ {StageBuildRuleIndices , "compile_stage_rebuild_indices" , c .buildRuleIndices },
474+ {StageBuildComprehensionIndices , "compile_stage_rebuild_comprehension_indices" , c .buildComprehensionIndices },
475+ {StageBuildRequiredCapabilities , "compile_stage_build_required_capabilities" , c .buildRequiredCapabilities },
463476 }
464477
465478 return c
@@ -499,12 +512,21 @@ func (c *Compiler) WithPathConflictsCheckRoots(rootPaths []string) *Compiler {
499512
500513// WithStageAfter registers a stage to run during compilation after
501514// the named stage.
515+ //
516+ // Caution: Consider using [ast.QueryCompiler.WithStageAfterID] instead. It provides
517+ // more (Golang) compile-time safety
502518func (c * Compiler ) WithStageAfter (after string , stage CompilerStageDefinition ) * Compiler {
503519 c .after [after ] = append (c .after [after ], stage )
504520 c .plan = nil // invalidate cached plan
505521 return c
506522}
507523
524+ // WithStageAfterID registers a stage to run during compilation after
525+ // the identified stage.
526+ func (c * Compiler ) WithStageAfterID (after StageID , stage CompilerStageDefinition ) * Compiler {
527+ return c .WithStageAfter (string (after ), stage )
528+ }
529+
508530// WithSkipStages configures the compiler to skip the specified stages during
509531// compilation. This invalidates any cached execution plan.
510532func (c * Compiler ) WithSkipStages (stages ... StageID ) * Compiler {
@@ -518,9 +540,9 @@ func (c *Compiler) WithSkipStages(stages ...StageID) *Compiler {
518540 return c
519541}
520542
521- // withOnlyStagesUpTo configures the compiler to run only stages up to and
543+ // WithOnlyStagesUpTo configures the compiler to run only stages up to and
522544// including the specified target stage. All stages after the target will be skipped.
523- func (c * Compiler ) withOnlyStagesUpTo (target StageID ) * Compiler {
545+ func (c * Compiler ) WithOnlyStagesUpTo (target StageID ) * Compiler {
524546 allStages := AllStages ()
525547 i := slices .Index (allStages , target )
526548 if i == - 1 {
@@ -1034,13 +1056,13 @@ func (c *Compiler) buildExecutionPlan() *executionPlan {
10341056 }
10351057
10361058 for _ , s := range c .stages {
1037- if _ , skip := c .skipStages [StageID ( s .name ) ]; skip {
1059+ if _ , skip := c .skipStages [s .name ]; skip {
10381060 continue
10391061 }
10401062
1041- plan .stages = append (plan .stages , plannedStage ( s ) )
1063+ plan .stages = append (plan .stages , plannedStage { name : string ( s . name ), metricName : s . metricName , f : s . f } )
10421064
1043- for _ , a := range c .after [s .name ] {
1065+ for _ , a := range c .after [string ( s .name ) ] {
10441066 if _ , skip := c .skipStages [StageID (a .Name )]; skip {
10451067 continue
10461068 }
@@ -3356,6 +3378,10 @@ func (qc *queryCompiler) WithStageAfter(after string, stage QueryCompilerStageDe
33563378 return qc
33573379}
33583380
3381+ func (qc * queryCompiler ) WithStageAfterID (after StageID , stage QueryCompilerStageDefinition ) QueryCompiler {
3382+ return qc .WithStageAfter (string (after ), stage )
3383+ }
3384+
33593385func (qc * queryCompiler ) WithUnsafeBuiltins (unsafe map [string ]struct {}) QueryCompiler {
33603386 qc .unsafeBuiltins = unsafe
33613387 return qc
@@ -3391,7 +3417,7 @@ func (qc *queryCompiler) runStageAfter(metricName string, query Body, s QueryCom
33913417}
33923418
33933419type queryStage = struct {
3394- name string
3420+ name StageID
33953421 metricName string
33963422 f func (* QueryContext , Body ) (Body , error )
33973423}
@@ -3404,21 +3430,21 @@ func (qc *queryCompiler) Compile(query Body) (Body, error) {
34043430 query = query .Copy ()
34053431
34063432 stages := []queryStage {
3407- {"CheckKeywordOverrides" , "query_compile_stage_check_keyword_overrides" , qc .checkKeywordOverrides },
3408- {"ResolveRefs" , "query_compile_stage_resolve_refs" , qc .resolveRefs },
3409- {"RewriteLocalVars" , "query_compile_stage_rewrite_local_vars" , qc .rewriteLocalVars },
3410- {"RewriteTemplateStrings" , "compile_stage_rewrite_template_strings" , qc .rewriteTemplateStrings },
3411- {"CheckVoidCalls" , "query_compile_stage_check_void_calls" , qc .checkVoidCalls },
3412- {"RewritePrintCalls" , "query_compile_stage_rewrite_print_calls" , qc .rewritePrintCalls },
3413- {"RewriteExprTerms" , "query_compile_stage_rewrite_expr_terms" , qc .rewriteExprTerms },
3414- {"RewriteComprehensionTerms" , "query_compile_stage_rewrite_comprehension_terms" , qc .rewriteComprehensionTerms },
3415- {"RewriteWithValues" , "query_compile_stage_rewrite_with_values" , qc .rewriteWithModifiers },
3416- {"CheckUndefinedFuncs" , "query_compile_stage_check_undefined_funcs" , qc .checkUndefinedFuncs },
3417- {"CheckSafety" , "query_compile_stage_check_safety" , qc .checkSafety },
3418- {"RewriteDynamicTerms" , "query_compile_stage_rewrite_dynamic_terms" , qc .rewriteDynamicTerms },
3419- {"CheckTypes" , "query_compile_stage_check_types" , qc .checkTypes },
3420- {"CheckUnsafeBuiltins" , "query_compile_stage_check_unsafe_builtins" , qc .checkUnsafeBuiltins },
3421- {"CheckDeprecatedBuiltins" , "query_compile_stage_check_deprecated_builtins" , qc .checkDeprecatedBuiltins },
3433+ {StageCheckKeywordOverrides , "query_compile_stage_check_keyword_overrides" , qc .checkKeywordOverrides },
3434+ {StageResolveRefs , "query_compile_stage_resolve_refs" , qc .resolveRefs },
3435+ {StageRewriteLocalVars , "query_compile_stage_rewrite_local_vars" , qc .rewriteLocalVars },
3436+ {StageRewriteTemplateStrings , "compile_stage_rewrite_template_strings" , qc .rewriteTemplateStrings },
3437+ {StageCheckVoidCalls , "query_compile_stage_check_void_calls" , qc .checkVoidCalls },
3438+ {StageRewritePrintCalls , "query_compile_stage_rewrite_print_calls" , qc .rewritePrintCalls },
3439+ {StageRewriteExprTerms , "query_compile_stage_rewrite_expr_terms" , qc .rewriteExprTerms },
3440+ {StageRewriteComprehensionTerms , "query_compile_stage_rewrite_comprehension_terms" , qc .rewriteComprehensionTerms },
3441+ {StageRewriteWithValues , "query_compile_stage_rewrite_with_values" , qc .rewriteWithModifiers },
3442+ {StageCheckUndefinedFuncs , "query_compile_stage_check_undefined_funcs" , qc .checkUndefinedFuncs },
3443+ {StageCheckSafety , "query_compile_stage_check_safety" , qc .checkSafety },
3444+ {StageRewriteDynamicTerms , "query_compile_stage_rewrite_dynamic_terms" , qc .rewriteDynamicTerms },
3445+ {StageCheckTypes , "query_compile_stage_check_types" , qc .checkTypes },
3446+ {StageCheckUnsafeBuiltins , "query_compile_stage_check_unsafe_builtins" , qc .checkUnsafeBuiltins },
3447+ {StageCheckDeprecatedBuiltins , "query_compile_stage_check_deprecated_builtins" , qc .checkDeprecatedBuiltins },
34223448 }
34233449 if qc .compiler .evalMode == EvalModeTopdown {
34243450 stages = append (stages , queryStage {"BuildComprehensionIndex" , "query_compile_stage_build_comprehension_index" , qc .buildComprehensionIndices })
@@ -3432,7 +3458,7 @@ func (qc *queryCompiler) Compile(query Body) (Body, error) {
34323458 if err != nil {
34333459 return nil , qc .applyErrorLimit (err )
34343460 }
3435- for _ , s := range qc .after [s .name ] {
3461+ for _ , s := range qc .after [string ( s .name ) ] {
34363462 query , err = qc .runStageAfter (s .MetricName , query , s .Stage )
34373463 if err != nil {
34383464 return nil , qc .applyErrorLimit (err )
0 commit comments