Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 33 additions & 29 deletions v1/compile/compile.go
Original file line number Diff line number Diff line change
Expand Up @@ -781,6 +781,27 @@ func findAnnotationsForTerm(term *ast.Term, annotationRefs []*ast.AnnotationsRef
return result
}

// pruneAnnotationsAndComments filters out annotations and their associated comments based on a predicate.
// It returns the kept annotations and kept comments.
func pruneAnnotationsAndComments(
module *ast.Module,
shouldDiscard func(*ast.Annotations) bool,
) ([]*ast.Annotations, []*ast.Comment) {
keepAnnotations := slices.DeleteFunc(slices.Clone(module.Annotations), shouldDiscard)

var keepComments []*ast.Comment
for _, comment := range module.Comments {
if slices.ContainsFunc(keepAnnotations, func(a *ast.Annotations) bool {
return comment.Location.Row >= a.Location.Row &&
comment.Location.Row <= a.EndLoc().Row
}) {
keepComments = append(keepComments, comment)
}
}

return keepAnnotations, keepComments
}

// pruneBundleEntrypoints will modify modules in the provided bundle to remove
// rules matching the entrypoints along with injecting import statements to
// preserve their ability to compile.
Expand Down Expand Up @@ -816,35 +837,10 @@ func pruneBundleEntrypoints(b *bundle.Bundle, entrypointrefs []*ast.Term) error
}
}

// Drop any Annotations for rules matching the entrypoint path
var annotations []*ast.Annotations
var prunedAnnotations []*ast.Annotations
for _, annotation := range mf.Parsed.Annotations {
p := annotation.GetTargetPath()
// We prune annotations of dropped rules, but not packages, as the Rego file is always retained
if p.Equal(entrypoint.Value) && !mf.Parsed.Package.Path.Equal(entrypoint.Value) {
prunedAnnotations = append(prunedAnnotations, annotation)
} else {
annotations = append(annotations, annotation)
}
}

// Drop comments associated with pruned annotations
var comments []*ast.Comment
for _, comment := range mf.Parsed.Comments {
pruned := false
for _, annotation := range prunedAnnotations {
if comment.Location.Row >= annotation.Location.Row &&
comment.Location.Row <= annotation.EndLoc().Row {
pruned = true
break
}
}

if !pruned {
comments = append(comments, comment)
}
}
// Prune annotations and comments for entrypoint rules
annotations, comments := pruneAnnotationsAndComments(mf.Parsed, func(annotation *ast.Annotations) bool {
return annotation.GetTargetPath().Equal(entrypoint.Value)
})

// If any rules or annotations were dropped update the module accordingly
if len(rules) != len(mf.Parsed.Rules) || len(comments) != len(mf.Parsed.Comments) {
Expand Down Expand Up @@ -1215,7 +1211,15 @@ func (*optimizer) merge(a, b []bundle.ModuleFile) []bundle.ModuleFile {
}

if len(keep) > 0 {
keepAnnotations, keepComments := pruneAnnotationsAndComments(a[i].Parsed, func(annotation *ast.Annotations) bool {
return discarded.Contains(ast.NewTerm(annotation.GetTargetPath()))
})

a[i].Parsed.Rules = keep
a[i].Parsed.Annotations = keepAnnotations
a[i].Parsed.Comments = keepComments
// Remove the original raw source, we're editing the AST
// directly, so it won't be in sync anymore.
a[i].Raw = nil
b = append(b, a[i])
}
Expand Down
115 changes: 115 additions & 0 deletions v1/compile/compile_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3531,6 +3531,121 @@ func TestOptimizerOutput(t *testing.T) {
}
}

func TestCompilerOptimizationMultipleAnnotationEntrypoints(t *testing.T) {
files := map[string]string{
"test.rego": `package pkg

# best allow rule ever
# METADATA
# entrypoint: true
allow if {
is_admin
has_permission(input.action)
}

# nice deny rule
# METADATA
# entrypoint: true
deny if {
input.user == "guest"
input.action == "delete"
}

# METADATA
# description: this rule should keep its MD
other if input.other

is_admin if input.user == "admin"
valid_actions := ["read", "write", "delete"]
has_permission(action) if action in valid_actions
`,
}

tests := []struct {
name string
optimizationLvl int
expectOptimized bool
}{
{
name: "optimization level 0",
optimizationLvl: 0,
expectOptimized: false,
},
{
name: "optimization level 1",
optimizationLvl: 1,
expectOptimized: true,
},
{
name: "optimization level 2",
optimizationLvl: 2,
expectOptimized: true,
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
test.WithTestFS(files, true, func(root string, fsys fs.FS) {
compiler := New().
WithFS(fsys).
WithPaths(root).
WithOptimizationLevel(tc.optimizationLvl).
WithRegoAnnotationEntrypoints(true)
err := compiler.Build(t.Context())
if err != nil {
t.Fatal(err)
}

// Check what the original "test.rego" remnants look like at different optimization levels
var originalModule *bundle.ModuleFile
for i := range compiler.bundle.Modules {
if compiler.bundle.Modules[i].Path == "test.rego" {
originalModule = &compiler.bundle.Modules[i]
break
}
}

if originalModule == nil {
t.Fatal("expected to find original 'test.rego' in bundle")
}

// The behavior differs between optimization levels
var expectedOriginal string
switch tc.optimizationLvl {
case 0: // At level 0, all rules remain unchanged
expectedOriginal = files["test.rego"]
case 1: // At level 1, entrypoint rules removed, is_admin and has_permission inlined
expectedOriginal = `package pkg

# METADATA
# description: this rule should keep its MD
other if input.other

valid_actions := ["read", "write", "delete"]
`
case 2: // At level 2, entrypoint rules removed, support rules preserved
expectedOriginal = `package pkg

# METADATA
# description: this rule should keep its MD
other if input.other

is_admin if input.user == "admin"
valid_actions := ["read", "write", "delete"]
has_permission(action) if action in valid_actions
`
}

actualOriginal := string(originalModule.Raw)
if actualOriginal != expectedOriginal {
t.Errorf("original module mismatch at level %d:\n\nexpected:\n%s\n\ngot:\n%s",
tc.optimizationLvl, expectedOriginal, actualOriginal)
}
})
})
}
}

func TestOptimizerError(t *testing.T) {
tests := []struct {
note string
Expand Down
Loading