Skip to content

Commit 72df34b

Browse files
authored
ast: add scaffolding to introspect and skip compiler stages (#8304)
* ast: allow skipping stages in compiler
* ast: add package function to return all stages
* ast: restructure loops (metrics vs no-metrics)
* ast: add withOnlyStagesUpTo() helper, use in tests

I thought about exporting withOnlyStagesUpTo(), but I wasn't certain it was warranted. We can export it later if there is an urgent need to call this method from outside the package.

Signed-off-by: Stephan Renatus <stephan.renatus@gmail.com>
1 parent 90d65b7 commit 72df34b

File tree

4 files changed

+617
-111
lines changed

4 files changed

+617
-111
lines changed

v1/ast/compile.go

Lines changed: 195 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,8 @@ type Compiler struct {
158158
evalMode CompilerEvalMode //
159159
rewriteTestRulesForTracing bool // rewrite test rules to capture dynamic values for tracing.
160160
defaultRegoVersion RegoVersion
161+
skipStages map[StageID]struct{} // stages to skip during compilation
162+
plan *executionPlan // computed execution plan (cached)
161163
}
162164

163165
func (c *Compiler) DefaultRegoVersion() RegoVersion {
@@ -167,6 +169,87 @@ func (c *Compiler) DefaultRegoVersion() RegoVersion {
167169
// CompilerStage defines the interface for stages in the compiler.
168170
type CompilerStage func(*Compiler) *Error
169171

172+
// StageID is the name under which a compiler stage is known. It is a
// distinct string type so that stage names cannot be confused with
// arbitrary strings at API boundaries.
type StageID string

// Identifiers of the compiler stages. Each constant's value is the stage's
// name as a string.
const (
	StageResolveRefs                = StageID("ResolveRefs")
	StageInitLocalVarGen            = StageID("InitLocalVarGen")
	StageRewriteRuleHeadRefs        = StageID("RewriteRuleHeadRefs")
	StageCheckKeywordOverrides      = StageID("CheckKeywordOverrides")
	StageCheckDuplicateImports      = StageID("CheckDuplicateImports")
	StageRemoveImports              = StageID("RemoveImports")
	StageSetModuleTree              = StageID("SetModuleTree")
	StageSetRuleTree                = StageID("SetRuleTree")
	StageRewriteLocalVars           = StageID("RewriteLocalVars")
	StageRewriteTemplateStrings     = StageID("RewriteTemplateStrings")
	StageCheckVoidCalls             = StageID("CheckVoidCalls")
	StageRewritePrintCalls          = StageID("RewritePrintCalls")
	StageRewriteExprTerms           = StageID("RewriteExprTerms")
	StageParseMetadataBlocks        = StageID("ParseMetadataBlocks")
	StageSetAnnotationSet           = StageID("SetAnnotationSet")
	StageRewriteRegoMetadataCalls   = StageID("RewriteRegoMetadataCalls")
	StageSetGraph                   = StageID("SetGraph")
	StageRewriteComprehensionTerms  = StageID("RewriteComprehensionTerms")
	StageRewriteRefsInHead          = StageID("RewriteRefsInHead")
	StageRewriteWithValues          = StageID("RewriteWithValues")
	StageCheckRuleConflicts         = StageID("CheckRuleConflicts")
	StageCheckUndefinedFuncs        = StageID("CheckUndefinedFuncs")
	StageCheckSafetyRuleHeads       = StageID("CheckSafetyRuleHeads")
	StageCheckSafetyRuleBodies      = StageID("CheckSafetyRuleBodies")
	StageRewriteEquals              = StageID("RewriteEquals")
	StageRewriteDynamicTerms        = StageID("RewriteDynamicTerms")
	StageRewriteTestRulesForTracing = StageID("RewriteTestRulesForTracing")
	StageCheckRecursion             = StageID("CheckRecursion")
	StageCheckTypes                 = StageID("CheckTypes")
	StageCheckUnsafeBuiltins        = StageID("CheckUnsafeBuiltins")
	StageCheckDeprecatedBuiltins    = StageID("CheckDeprecatedBuiltins")
	StageBuildRuleIndices           = StageID("BuildRuleIndices")
	StageBuildComprehensionIndices  = StageID("BuildComprehensionIndices")
	StageBuildRequiredCapabilities  = StageID("BuildRequiredCapabilities")
)
212+
213+
// AllStages returns the complete list of compiler stages in execution order.
214+
func AllStages() []StageID {
215+
return []StageID{
216+
StageResolveRefs,
217+
StageInitLocalVarGen,
218+
StageRewriteRuleHeadRefs,
219+
StageCheckKeywordOverrides,
220+
StageCheckDuplicateImports,
221+
StageRemoveImports,
222+
StageSetModuleTree,
223+
StageSetRuleTree,
224+
StageRewriteLocalVars,
225+
StageRewriteTemplateStrings,
226+
StageCheckVoidCalls,
227+
StageRewritePrintCalls,
228+
StageRewriteExprTerms,
229+
StageParseMetadataBlocks,
230+
StageSetAnnotationSet,
231+
StageRewriteRegoMetadataCalls,
232+
StageSetGraph,
233+
StageRewriteComprehensionTerms,
234+
StageRewriteRefsInHead,
235+
StageRewriteWithValues,
236+
StageCheckRuleConflicts,
237+
StageCheckUndefinedFuncs,
238+
StageCheckSafetyRuleHeads,
239+
StageCheckSafetyRuleBodies,
240+
StageRewriteEquals,
241+
StageRewriteDynamicTerms,
242+
StageRewriteTestRulesForTracing,
243+
StageCheckRecursion,
244+
StageCheckTypes,
245+
StageCheckUnsafeBuiltins,
246+
StageCheckDeprecatedBuiltins,
247+
StageBuildRuleIndices,
248+
StageBuildComprehensionIndices,
249+
StageBuildRequiredCapabilities,
250+
}
251+
}
252+
170253
// CompilerEvalMode allows toggling certain stages that are only
171254
// needed for certain modes. Concretely, only "topdown" mode will
172255
// have the compiler build comprehension and rule indices.
@@ -189,6 +272,18 @@ type CompilerStageDefinition struct {
189272
Stage CompilerStage
190273
}
191274

275+
type (
	// executionPlan is the ordered sequence of stages a compilation run
	// will execute.
	executionPlan struct {
		stages []plannedStage
	}

	// plannedStage is one entry of an executionPlan: a stage name, the name
	// of the metric timer associated with it, and the function that runs it.
	plannedStage struct {
		name       string
		metricName string
		f          func()
	}
)
286+
192287
// RulesOptions defines the options for retrieving rules by Ref from the
193288
// compiler.
194289
type RulesOptions struct {
@@ -406,9 +501,34 @@ func (c *Compiler) WithPathConflictsCheckRoots(rootPaths []string) *Compiler {
406501
// the named stage.
407502
func (c *Compiler) WithStageAfter(after string, stage CompilerStageDefinition) *Compiler {
408503
c.after[after] = append(c.after[after], stage)
504+
c.plan = nil // invalidate cached plan
505+
return c
506+
}
507+
508+
// WithSkipStages configures the compiler to skip the specified stages during
509+
// compilation. This invalidates any cached execution plan.
510+
func (c *Compiler) WithSkipStages(stages ...StageID) *Compiler {
511+
if c.skipStages == nil {
512+
c.skipStages = make(map[StageID]struct{}, len(stages))
513+
}
514+
for _, s := range stages {
515+
c.skipStages[s] = struct{}{}
516+
}
517+
c.plan = nil // invalidate cached plan
409518
return c
410519
}
411520

521+
// withOnlyStagesUpTo configures the compiler to run only stages up to and
522+
// including the specified target stage. All stages after the target will be skipped.
523+
func (c *Compiler) withOnlyStagesUpTo(target StageID) *Compiler {
524+
allStages := AllStages()
525+
i := slices.Index(allStages, target)
526+
if i == -1 {
527+
return c
528+
}
529+
return c.WithSkipStages(allStages[i+1:]...)
530+
}
531+
412532
// WithMetrics will set a metrics.Metrics and be used for profiling
413533
// the Compiler instance.
414534
func (c *Compiler) WithMetrics(metrics metrics.Metrics) *Compiler {
@@ -906,6 +1026,61 @@ func (c *Compiler) WithDefaultRegoVersion(regoVersion RegoVersion) *Compiler {
9061026
return c
9071027
}
9081028

1029+
// buildExecutionPlan creates the unified list of stages to execute, including
1030+
// both main stages and "after" stages, with filtering applied.
1031+
func (c *Compiler) buildExecutionPlan() *executionPlan {
1032+
plan := &executionPlan{
1033+
stages: make([]plannedStage, 0, len(c.stages)*2),
1034+
}
1035+
1036+
for _, s := range c.stages {
1037+
if _, skip := c.skipStages[StageID(s.name)]; skip {
1038+
continue
1039+
}
1040+
1041+
plan.stages = append(plan.stages, plannedStage(s))
1042+
1043+
for _, a := range c.after[s.name] {
1044+
if _, skip := c.skipStages[StageID(a.Name)]; skip {
1045+
continue
1046+
}
1047+
1048+
afterStage := a // Capture variables in closure properly
1049+
plan.stages = append(plan.stages, plannedStage{
1050+
name: afterStage.Name,
1051+
metricName: afterStage.MetricName,
1052+
f: func() {
1053+
if err := afterStage.Stage(c); err != nil {
1054+
c.err(err)
1055+
}
1056+
},
1057+
})
1058+
}
1059+
}
1060+
1061+
return plan
1062+
}
1063+
1064+
// getOrBuildPlan ensures we have a valid execution plan.
1065+
func (c *Compiler) getOrBuildPlan() *executionPlan {
1066+
if c.plan == nil {
1067+
c.plan = c.buildExecutionPlan()
1068+
}
1069+
return c.plan
1070+
}
1071+
1072+
// StagesToRun returns the list of stage IDs that will be executed during
1073+
// compilation, in execution order. This includes both main stages and any
1074+
// registered "after" stages.
1075+
func (c *Compiler) StagesToRun() []StageID {
1076+
plan := c.getOrBuildPlan()
1077+
result := make([]StageID, len(plan.stages))
1078+
for i, s := range plan.stages {
1079+
result[i] = StageID(s.name)
1080+
}
1081+
return result
1082+
}
1083+
9091084
func (c *Compiler) counterAdd(name string, n uint64) {
9101085
if c.metrics == nil {
9111086
return
@@ -1659,42 +1834,22 @@ func (c *Compiler) checkDeprecatedBuiltins() {
16591834
}
16601835
}
16611836

1662-
func (c *Compiler) runStage(metricName string, f func()) {
1663-
if c.metrics != nil {
1664-
c.metrics.Timer(metricName).Start()
1665-
defer c.metrics.Timer(metricName).Stop()
1666-
}
1667-
f()
1668-
}
1837+
func (c *Compiler) compile() {
1838+
plan := c.getOrBuildPlan()
16691839

1670-
func (c *Compiler) runStageAfter(metricName string, s CompilerStage) *Error {
16711840
if c.metrics != nil {
1672-
c.metrics.Timer(metricName).Start()
1673-
defer c.metrics.Timer(metricName).Stop()
1674-
}
1675-
return s(c)
1676-
}
1677-
1678-
func (c *Compiler) compile() {
1679-
for _, s := range c.stages {
1680-
if c.evalMode == EvalModeIR {
1681-
switch s.name {
1682-
case "BuildRuleIndices", "BuildComprehensionIndices":
1683-
continue // skip these stages
1841+
for _, s := range plan.stages {
1842+
c.metrics.Timer(s.metricName).Start()
1843+
s.f()
1844+
c.metrics.Timer(s.metricName).Stop()
1845+
if c.Failed() {
1846+
return
16841847
}
16851848
}
1686-
1687-
if c.allowUndefinedFuncCalls && (s.name == "CheckUndefinedFuncs" || s.name == "CheckSafetyRuleBodies") {
1688-
continue
1689-
}
1690-
1691-
c.runStage(s.metricName, s.f)
1692-
if c.Failed() {
1693-
return
1694-
}
1695-
for _, a := range c.after[s.name] {
1696-
if err := c.runStageAfter(a.MetricName, a.Stage); err != nil {
1697-
c.err(err)
1849+
} else {
1850+
for _, s := range plan.stages {
1851+
s.f()
1852+
if c.Failed() {
16981853
return
16991854
}
17001855
}
@@ -1766,6 +1921,14 @@ func (c *Compiler) init() {
17661921
WithInputType(c.inputType).
17671922
Env(c.builtins)
17681923

1924+
// Configure default stage skips based on existing configuration
1925+
if c.evalMode == EvalModeIR {
1926+
c.WithSkipStages(StageBuildRuleIndices, StageBuildComprehensionIndices)
1927+
}
1928+
if c.allowUndefinedFuncCalls {
1929+
c.WithSkipStages(StageCheckUndefinedFuncs, StageCheckSafetyRuleBodies)
1930+
}
1931+
17691932
c.initialized = true
17701933
}
17711934

0 commit comments

Comments
 (0)