Skip to content

Commit 974015b

Browse files
authored
boolvalidator: add Equals validator (#232)
* boolvalidator: add Equals validator This validator can be used in cases where non-null boolean value should be exactly `true` or exactly `false`. ```console % go test -count=1 ./boolvalidator/... ok github.com/hashicorp/terraform-plugin-framework-validators/boolvalidator 0.251s ``` * Update boolvalidator/equals.go Co-authored-by: Austin Valle <[email protected]> * chore: changelog
1 parent 7979126 commit 974015b

File tree

3 files changed

+125
-0
lines changed

3 files changed

+125
-0
lines changed
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
kind: FEATURES
2+
body: 'boolvalidator: Added `Equals` validator'
3+
time: 2024-09-20T16:48:52.562758-04:00
4+
custom:
5+
Issue: "232"

boolvalidator/equals.go

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
// Copyright (c) HashiCorp, Inc.
2+
// SPDX-License-Identifier: MPL-2.0
3+
4+
package boolvalidator
5+
6+
import (
7+
"context"
8+
"fmt"
9+
10+
"github.com/hashicorp/terraform-plugin-framework-validators/helpers/validatordiag"
11+
"github.com/hashicorp/terraform-plugin-framework/schema/validator"
12+
"github.com/hashicorp/terraform-plugin-framework/types"
13+
)
14+
15+
var _ validator.Bool = equalsValidator{}
16+
17+
type equalsValidator struct {
18+
value types.Bool
19+
}
20+
21+
func (v equalsValidator) Description(ctx context.Context) string {
22+
return fmt.Sprintf("Value must be %q", v.value)
23+
}
24+
25+
func (v equalsValidator) MarkdownDescription(ctx context.Context) string {
26+
return v.Description(ctx)
27+
}
28+
29+
func (v equalsValidator) ValidateBool(ctx context.Context, req validator.BoolRequest, resp *validator.BoolResponse) {
30+
if req.ConfigValue.IsNull() || req.ConfigValue.IsUnknown() {
31+
return
32+
}
33+
34+
configValue := req.ConfigValue
35+
36+
if !configValue.Equal(v.value) {
37+
resp.Diagnostics.Append(validatordiag.InvalidAttributeValueMatchDiagnostic(
38+
req.Path,
39+
v.Description(ctx),
40+
configValue.String(),
41+
))
42+
}
43+
}
44+
45+
// Equals returns an AttributeValidator which ensures that the configured boolean attribute
46+
// matches the given `value`. Null (unconfigured) and unknown (known after apply) values are skipped.
47+
func Equals(value bool) validator.Bool {
48+
return equalsValidator{
49+
value: types.BoolValue(value),
50+
}
51+
}

boolvalidator/equals_test.go

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
// Copyright (c) HashiCorp, Inc.
2+
// SPDX-License-Identifier: MPL-2.0
3+
4+
package boolvalidator_test
5+
6+
import (
7+
"context"
8+
"testing"
9+
10+
"github.com/hashicorp/terraform-plugin-framework-validators/boolvalidator"
11+
"github.com/hashicorp/terraform-plugin-framework/schema/validator"
12+
"github.com/hashicorp/terraform-plugin-framework/types"
13+
)
14+
15+
func TestEqualsValidator(t *testing.T) {
16+
t.Parallel()
17+
18+
type testCase struct {
19+
in types.Bool
20+
validator validator.Bool
21+
expErrors int
22+
}
23+
24+
testCases := map[string]testCase{
25+
"simple-match": {
26+
in: types.BoolValue(true),
27+
validator: boolvalidator.Equals(true),
28+
expErrors: 0,
29+
},
30+
"simple-mismatch": {
31+
in: types.BoolValue(false),
32+
validator: boolvalidator.Equals(true),
33+
expErrors: 1,
34+
},
35+
"skip-validation-on-null": {
36+
in: types.BoolNull(),
37+
validator: boolvalidator.Equals(true),
38+
expErrors: 0,
39+
},
40+
"skip-validation-on-unknown": {
41+
in: types.BoolUnknown(),
42+
validator: boolvalidator.Equals(true),
43+
expErrors: 0,
44+
},
45+
}
46+
47+
for name, test := range testCases {
48+
t.Run(name, func(t *testing.T) {
49+
t.Parallel()
50+
req := validator.BoolRequest{
51+
ConfigValue: test.in,
52+
}
53+
res := validator.BoolResponse{}
54+
test.validator.ValidateBool(context.TODO(), req, &res)
55+
56+
if test.expErrors > 0 && !res.Diagnostics.HasError() {
57+
t.Fatalf("expected %d error(s), got none", test.expErrors)
58+
}
59+
60+
if test.expErrors > 0 && test.expErrors != res.Diagnostics.ErrorsCount() {
61+
t.Fatalf("expected %d error(s), got %d: %v", test.expErrors, res.Diagnostics.ErrorsCount(), res.Diagnostics)
62+
}
63+
64+
if test.expErrors == 0 && res.Diagnostics.HasError() {
65+
t.Fatalf("expected no error(s), got %d: %v", res.Diagnostics.ErrorsCount(), res.Diagnostics)
66+
}
67+
})
68+
}
69+
}

0 commit comments

Comments
 (0)