@@ -158,6 +158,8 @@ type Compiler struct {
158158 evalMode CompilerEvalMode //
159159 rewriteTestRulesForTracing bool // rewrite test rules to capture dynamic values for tracing.
160160 defaultRegoVersion RegoVersion
161+ skipStages map [StageID ]struct {} // stages to skip during compilation
162+ plan * executionPlan // computed execution plan (cached)
161163}
162164
163165func (c * Compiler ) DefaultRegoVersion () RegoVersion {
@@ -167,6 +169,87 @@ func (c *Compiler) DefaultRegoVersion() RegoVersion {
167169// CompilerStage defines the interface for stages in the compiler.
168170type CompilerStage func (* Compiler ) * Error
169171
172+ // StageID uniquely identifies a compiler stage.
173+ type StageID string
174+
175+ // Compiler stage identifiers.
176+ const (
177+ StageResolveRefs StageID = "ResolveRefs"
178+ StageInitLocalVarGen StageID = "InitLocalVarGen"
179+ StageRewriteRuleHeadRefs StageID = "RewriteRuleHeadRefs"
180+ StageCheckKeywordOverrides StageID = "CheckKeywordOverrides"
181+ StageCheckDuplicateImports StageID = "CheckDuplicateImports"
182+ StageRemoveImports StageID = "RemoveImports"
183+ StageSetModuleTree StageID = "SetModuleTree"
184+ StageSetRuleTree StageID = "SetRuleTree"
185+ StageRewriteLocalVars StageID = "RewriteLocalVars"
186+ StageRewriteTemplateStrings StageID = "RewriteTemplateStrings"
187+ StageCheckVoidCalls StageID = "CheckVoidCalls"
188+ StageRewritePrintCalls StageID = "RewritePrintCalls"
189+ StageRewriteExprTerms StageID = "RewriteExprTerms"
190+ StageParseMetadataBlocks StageID = "ParseMetadataBlocks"
191+ StageSetAnnotationSet StageID = "SetAnnotationSet"
192+ StageRewriteRegoMetadataCalls StageID = "RewriteRegoMetadataCalls"
193+ StageSetGraph StageID = "SetGraph"
194+ StageRewriteComprehensionTerms StageID = "RewriteComprehensionTerms"
195+ StageRewriteRefsInHead StageID = "RewriteRefsInHead"
196+ StageRewriteWithValues StageID = "RewriteWithValues"
197+ StageCheckRuleConflicts StageID = "CheckRuleConflicts"
198+ StageCheckUndefinedFuncs StageID = "CheckUndefinedFuncs"
199+ StageCheckSafetyRuleHeads StageID = "CheckSafetyRuleHeads"
200+ StageCheckSafetyRuleBodies StageID = "CheckSafetyRuleBodies"
201+ StageRewriteEquals StageID = "RewriteEquals"
202+ StageRewriteDynamicTerms StageID = "RewriteDynamicTerms"
203+ StageRewriteTestRulesForTracing StageID = "RewriteTestRulesForTracing"
204+ StageCheckRecursion StageID = "CheckRecursion"
205+ StageCheckTypes StageID = "CheckTypes"
206+ StageCheckUnsafeBuiltins StageID = "CheckUnsafeBuiltins"
207+ StageCheckDeprecatedBuiltins StageID = "CheckDeprecatedBuiltins"
208+ StageBuildRuleIndices StageID = "BuildRuleIndices"
209+ StageBuildComprehensionIndices StageID = "BuildComprehensionIndices"
210+ StageBuildRequiredCapabilities StageID = "BuildRequiredCapabilities"
211+ )
212+
213+ // AllStages returns the complete list of compiler stages in execution order.
214+ func AllStages () []StageID {
215+ return []StageID {
216+ StageResolveRefs ,
217+ StageInitLocalVarGen ,
218+ StageRewriteRuleHeadRefs ,
219+ StageCheckKeywordOverrides ,
220+ StageCheckDuplicateImports ,
221+ StageRemoveImports ,
222+ StageSetModuleTree ,
223+ StageSetRuleTree ,
224+ StageRewriteLocalVars ,
225+ StageRewriteTemplateStrings ,
226+ StageCheckVoidCalls ,
227+ StageRewritePrintCalls ,
228+ StageRewriteExprTerms ,
229+ StageParseMetadataBlocks ,
230+ StageSetAnnotationSet ,
231+ StageRewriteRegoMetadataCalls ,
232+ StageSetGraph ,
233+ StageRewriteComprehensionTerms ,
234+ StageRewriteRefsInHead ,
235+ StageRewriteWithValues ,
236+ StageCheckRuleConflicts ,
237+ StageCheckUndefinedFuncs ,
238+ StageCheckSafetyRuleHeads ,
239+ StageCheckSafetyRuleBodies ,
240+ StageRewriteEquals ,
241+ StageRewriteDynamicTerms ,
242+ StageRewriteTestRulesForTracing ,
243+ StageCheckRecursion ,
244+ StageCheckTypes ,
245+ StageCheckUnsafeBuiltins ,
246+ StageCheckDeprecatedBuiltins ,
247+ StageBuildRuleIndices ,
248+ StageBuildComprehensionIndices ,
249+ StageBuildRequiredCapabilities ,
250+ }
251+ }
252+
170253// CompilerEvalMode allows toggling certain stages that are only
171254// needed for certain modes, Concretely, only "topdown" mode will
172255// have the compiler build comprehension and rule indices.
@@ -189,6 +272,18 @@ type CompilerStageDefinition struct {
189272 Stage CompilerStage
190273}
191274
275+ // executionPlan represents the complete ordered list of stages to execute.
276+ type executionPlan struct {
277+ stages []plannedStage
278+ }
279+
280+ // plannedStage represents a single stage in the execution plan.
281+ type plannedStage struct {
282+ name string
283+ metricName string
284+ f func ()
285+ }
286+
192287// RulesOptions defines the options for retrieving rules by Ref from the
193288// compiler.
194289type RulesOptions struct {
@@ -406,9 +501,34 @@ func (c *Compiler) WithPathConflictsCheckRoots(rootPaths []string) *Compiler {
406501// the named stage.
407502func (c * Compiler ) WithStageAfter (after string , stage CompilerStageDefinition ) * Compiler {
408503 c .after [after ] = append (c .after [after ], stage )
504+ c .plan = nil // invalidate cached plan
505+ return c
506+ }
507+
508+ // WithSkipStages configures the compiler to skip the specified stages during
509+ // compilation. This invalidates any cached execution plan.
510+ func (c * Compiler ) WithSkipStages (stages ... StageID ) * Compiler {
511+ if c .skipStages == nil {
512+ c .skipStages = make (map [StageID ]struct {}, len (stages ))
513+ }
514+ for _ , s := range stages {
515+ c .skipStages [s ] = struct {}{}
516+ }
517+ c .plan = nil // invalidate cached plan
409518 return c
410519}
411520
521+ // withOnlyStagesUpTo configures the compiler to run only stages up to and
522+ // including the specified target stage. All stages after the target will be skipped.
523+ func (c * Compiler ) withOnlyStagesUpTo (target StageID ) * Compiler {
524+ allStages := AllStages ()
525+ i := slices .Index (allStages , target )
526+ if i == - 1 {
527+ return c
528+ }
529+ return c .WithSkipStages (allStages [i + 1 :]... )
530+ }
531+
412532// WithMetrics will set a metrics.Metrics and be used for profiling
413533// the Compiler instance.
414534func (c * Compiler ) WithMetrics (metrics metrics.Metrics ) * Compiler {
@@ -906,6 +1026,61 @@ func (c *Compiler) WithDefaultRegoVersion(regoVersion RegoVersion) *Compiler {
9061026 return c
9071027}
9081028
1029+ // buildExecutionPlan creates the unified list of stages to execute, including
1030+ // both main stages and "after" stages, with filtering applied.
1031+ func (c * Compiler ) buildExecutionPlan () * executionPlan {
1032+ plan := & executionPlan {
1033+ stages : make ([]plannedStage , 0 , len (c .stages )* 2 ),
1034+ }
1035+
1036+ for _ , s := range c .stages {
1037+ if _ , skip := c .skipStages [StageID (s .name )]; skip {
1038+ continue
1039+ }
1040+
1041+ plan .stages = append (plan .stages , plannedStage (s ))
1042+
1043+ for _ , a := range c .after [s .name ] {
1044+ if _ , skip := c .skipStages [StageID (a .Name )]; skip {
1045+ continue
1046+ }
1047+
1048+ afterStage := a // Capture variables in closure properly
1049+ plan .stages = append (plan .stages , plannedStage {
1050+ name : afterStage .Name ,
1051+ metricName : afterStage .MetricName ,
1052+ f : func () {
1053+ if err := afterStage .Stage (c ); err != nil {
1054+ c .err (err )
1055+ }
1056+ },
1057+ })
1058+ }
1059+ }
1060+
1061+ return plan
1062+ }
1063+
1064+ // getOrBuildPlan ensures we have a valid execution plan.
1065+ func (c * Compiler ) getOrBuildPlan () * executionPlan {
1066+ if c .plan == nil {
1067+ c .plan = c .buildExecutionPlan ()
1068+ }
1069+ return c .plan
1070+ }
1071+
1072+ // StagesToRun returns the list of stage IDs that will be executed during
1073+ // compilation, in execution order. This includes both main stages and any
1074+ // registered "after" stages.
1075+ func (c * Compiler ) StagesToRun () []StageID {
1076+ plan := c .getOrBuildPlan ()
1077+ result := make ([]StageID , len (plan .stages ))
1078+ for i , s := range plan .stages {
1079+ result [i ] = StageID (s .name )
1080+ }
1081+ return result
1082+ }
1083+
9091084func (c * Compiler ) counterAdd (name string , n uint64 ) {
9101085 if c .metrics == nil {
9111086 return
@@ -1659,42 +1834,22 @@ func (c *Compiler) checkDeprecatedBuiltins() {
16591834 }
16601835}
16611836
1662- func (c * Compiler ) runStage (metricName string , f func ()) {
1663- if c .metrics != nil {
1664- c .metrics .Timer (metricName ).Start ()
1665- defer c .metrics .Timer (metricName ).Stop ()
1666- }
1667- f ()
1668- }
1837+ func (c * Compiler ) compile () {
1838+ plan := c .getOrBuildPlan ()
16691839
1670- func (c * Compiler ) runStageAfter (metricName string , s CompilerStage ) * Error {
16711840 if c .metrics != nil {
1672- c .metrics .Timer (metricName ).Start ()
1673- defer c .metrics .Timer (metricName ).Stop ()
1674- }
1675- return s (c )
1676- }
1677-
1678- func (c * Compiler ) compile () {
1679- for _ , s := range c .stages {
1680- if c .evalMode == EvalModeIR {
1681- switch s .name {
1682- case "BuildRuleIndices" , "BuildComprehensionIndices" :
1683- continue // skip these stages
1841+ for _ , s := range plan .stages {
1842+ c .metrics .Timer (s .metricName ).Start ()
1843+ s .f ()
1844+ c .metrics .Timer (s .metricName ).Stop ()
1845+ if c .Failed () {
1846+ return
16841847 }
16851848 }
1686-
1687- if c .allowUndefinedFuncCalls && (s .name == "CheckUndefinedFuncs" || s .name == "CheckSafetyRuleBodies" ) {
1688- continue
1689- }
1690-
1691- c .runStage (s .metricName , s .f )
1692- if c .Failed () {
1693- return
1694- }
1695- for _ , a := range c .after [s .name ] {
1696- if err := c .runStageAfter (a .MetricName , a .Stage ); err != nil {
1697- c .err (err )
1849+ } else {
1850+ for _ , s := range plan .stages {
1851+ s .f ()
1852+ if c .Failed () {
16981853 return
16991854 }
17001855 }
@@ -1766,6 +1921,14 @@ func (c *Compiler) init() {
17661921 WithInputType (c .inputType ).
17671922 Env (c .builtins )
17681923
1924+ // Configure default stage skips based on existing configuration
1925+ if c .evalMode == EvalModeIR {
1926+ c .WithSkipStages (StageBuildRuleIndices , StageBuildComprehensionIndices )
1927+ }
1928+ if c .allowUndefinedFuncCalls {
1929+ c .WithSkipStages (StageCheckUndefinedFuncs , StageCheckSafetyRuleBodies )
1930+ }
1931+
17691932 c .initialized = true
17701933}
17711934
0 commit comments