Skip to content

Commit 9cdabe3

Browse files
committed
Add tests
1 parent a1822bc commit 9cdabe3

File tree

2 files changed

+119
-0
lines changed

2 files changed

+119
-0
lines changed

go.mod

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ go 1.23.4
44

55
require (
66
github.com/go-sprout/sprout v1.0.0
7+
github.com/google/go-cmp v0.6.0
78
github.com/open-policy-agent/frameworks/constraint v0.0.0-20220218180203-c2a0d8cdf85a
89
github.com/open-policy-agent/opa v1.1.0
910
github.com/sirupsen/logrus v1.9.3

internal/commands/create_test.go

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
package commands
2+
3+
import (
4+
"bytes"
5+
"os"
6+
"testing"
7+
8+
"github.com/google/go-cmp/cmp"
9+
log "github.com/sirupsen/logrus/hooks/test"
10+
11+
"github.com/plexsystems/konstraint/internal/rego"
12+
)
13+
14+
func TestRenderConstraint(t *testing.T) {
15+
_, entry := log.NewNullLogger()
16+
17+
violations, err := GetViolations()
18+
if err != nil {
19+
t.Errorf("Error getting violations: %v", err)
20+
}
21+
22+
expected, err := os.ReadFile("../../test/constraint_Test.yaml")
23+
if err != nil {
24+
t.Errorf("Error reading expected file: %v", err)
25+
}
26+
27+
// Need to remove carriage return for testing on Windows
28+
expected = bytes.Replace(expected, []byte("\r"), []byte(""), -1)
29+
30+
actual, err := renderConstraint(violations[0], "", entry.LastEntry())
31+
if err != nil {
32+
t.Errorf("Error rendering constraint: %v", err)
33+
}
34+
35+
if !bytes.Equal(actual, expected) {
36+
t.Errorf("Unexpected rendered template:\n %v", cmp.Diff(string(expected), string(actual)))
37+
}
38+
}
39+
40+
func TestRenderConstraintWithCustomTemplate(t *testing.T) {
41+
_, entry := log.NewNullLogger()
42+
43+
violations, err := GetViolations()
44+
if err != nil {
45+
t.Errorf("Error getting violations: %v", err)
46+
}
47+
48+
expected, err := os.ReadFile("../../test/custom/constraint_Test.yaml")
49+
if err != nil {
50+
t.Errorf("Error reading expected file: %v", err)
51+
}
52+
53+
actual, err := renderConstraint(violations[0], "constraint_template.tpl", entry.LastEntry())
54+
if err != nil {
55+
t.Errorf("Error rendering constraint: %v", err)
56+
}
57+
58+
if !bytes.Equal(actual, expected) {
59+
t.Errorf("Unexpected rendered template:\n %v", cmp.Diff(string(expected), string(actual)))
60+
}
61+
}
62+
63+
func TestRenderConstraintTemplate(t *testing.T) {
64+
_, entry := log.NewNullLogger()
65+
66+
violations, err := GetViolations()
67+
if err != nil {
68+
t.Errorf("Error getting violations: %v", err)
69+
}
70+
71+
expected, err := os.ReadFile("../../test/template_Test.yaml")
72+
if err != nil {
73+
t.Errorf("Error reading expected file: %v", err)
74+
}
75+
76+
// Need to remove carriage return for testing on windows
77+
expected = bytes.Replace(expected, []byte("\r"), []byte(""), -1)
78+
79+
actual, err := renderConstraintTemplate(violations[0], "v1beta1", "", entry.LastEntry())
80+
if err != nil {
81+
t.Errorf("Error rendering constrainttemplate: %v", err)
82+
}
83+
84+
if !bytes.Equal(actual, expected) {
85+
t.Errorf("Unexpected rendered template:\n %v", cmp.Diff(string(expected), string(actual)))
86+
}
87+
}
88+
89+
func TestRenderConstraintTemplateWithCustomTemplate(t *testing.T) {
90+
_, entry := log.NewNullLogger()
91+
92+
violations, err := GetViolations()
93+
if err != nil {
94+
t.Errorf("Error getting violations: %v", err)
95+
}
96+
97+
expected, err := os.ReadFile("../../test/custom/template_Test.yaml")
98+
if err != nil {
99+
t.Errorf("Error reading expected file: %v", err)
100+
}
101+
102+
actual, err := renderConstraintTemplate(violations[0], "v1", "constrainttemplate_template.tpl", entry.LastEntry())
103+
if err != nil {
104+
t.Errorf("Error rendering constrainttemplate: %v", err)
105+
}
106+
107+
if !bytes.Equal(actual, expected) {
108+
t.Errorf("Unexpected rendered template:\n %v", cmp.Diff(string(expected), string(actual)))
109+
}
110+
}
111+
112+
func GetViolations() ([]rego.Rego, error) {
113+
violations, err := rego.GetViolations("../../test")
114+
if err != nil {
115+
return nil, err
116+
}
117+
return violations, nil
118+
}

0 commit comments

Comments
 (0)