Skip to content

Commit d8cdc2d

Browse files
knownvalue: add custom check functions for validation
1 parent 5bd829a commit d8cdc2d

18 files changed

+857
-385
lines changed

knownvalue/bool_func.go

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
// Copyright (c) HashiCorp, Inc.
2+
// SPDX-License-Identifier: MPL-2.0
3+
4+
package knownvalue
5+
6+
import "fmt"
7+
8+
var _ Check = boolFunc{}
9+
10+
type boolFunc struct {
11+
checkFunc func(v bool) error
12+
}
13+
14+
// CheckValue determines whether the passed value is of type bool, and
15+
// returns no error from the provided check function
16+
func (v boolFunc) CheckValue(other any) error {
17+
val, ok := other.(bool)
18+
19+
if !ok {
20+
return fmt.Errorf("expected bool value for BoolFunc check, got: %T", other)
21+
}
22+
23+
return v.checkFunc(val)
24+
}
25+
26+
// String returns the bool representation of the value.
27+
func (v boolFunc) String() string {
28+
// Validation is up the the implementer of the function, so there are no
29+
// bool literal or regex comparers to print here
30+
return "BoolFunc"
31+
}
32+
33+
// BoolFunc returns a Check for passing the bool value in state
34+
// to the provided check function
35+
func BoolFunc(fn func(v bool) error) boolFunc {
36+
return boolFunc{
37+
checkFunc: fn,
38+
}
39+
}

knownvalue/bool_func_test.go

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
// Copyright (c) HashiCorp, Inc.
2+
// SPDX-License-Identifier: MPL-2.0
3+
4+
package knownvalue_test
5+
6+
import (
7+
"encoding/json"
8+
"fmt"
9+
"testing"
10+
11+
"github.com/google/go-cmp/cmp"
12+
13+
"github.com/hashicorp/terraform-plugin-testing/knownvalue"
14+
)
15+
16+
func TestBoolFunc_CheckValue(t *testing.T) {
17+
t.Parallel()
18+
19+
testCases := map[string]struct {
20+
self knownvalue.Check
21+
other any
22+
expectedError error
23+
}{
24+
"nil": {
25+
self: knownvalue.BoolFunc(func(bool) error { return nil }),
26+
expectedError: fmt.Errorf("expected bool value for BoolFunc check, got: <nil>"),
27+
},
28+
"wrong-type": {
29+
self: knownvalue.BoolFunc(func(bool) error { return nil }),
30+
other: json.Number("1.234"),
31+
expectedError: fmt.Errorf("expected bool value for BoolFunc check, got: float64"),
32+
},
33+
"failure": {
34+
self: knownvalue.BoolFunc(func(b bool) error {
35+
if b != true {
36+
return fmt.Errorf("%t was not true", b)
37+
}
38+
return nil
39+
}),
40+
other: false,
41+
expectedError: fmt.Errorf("%t was not true", false),
42+
},
43+
"success": {
44+
self: knownvalue.BoolFunc(func(b bool) error {
45+
if b != true {
46+
return fmt.Errorf("%t was not foo", b)
47+
}
48+
return nil
49+
}),
50+
other: true,
51+
},
52+
}
53+
54+
for name, testCase := range testCases {
55+
name, testCase := name, testCase
56+
57+
t.Run(name, func(t *testing.T) {
58+
t.Parallel()
59+
60+
got := testCase.self.CheckValue(testCase.other)
61+
62+
if diff := cmp.Diff(got, testCase.expectedError, equateErrorMessage); diff != "" {
63+
t.Errorf("unexpected difference: %s", diff)
64+
}
65+
})
66+
}
67+
}
68+
69+
func TestBoolFunc_String(t *testing.T) {
70+
t.Parallel()
71+
72+
got := knownvalue.BoolFunc(func(bool) error { return nil }).String()
73+
74+
if diff := cmp.Diff(got, "BoolFunc"); diff != "" {
75+
t.Errorf("unexpected difference: %s", diff)
76+
}
77+
}

knownvalue/float32_func.go

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
// Copyright (c) HashiCorp, Inc.
2+
// SPDX-License-Identifier: MPL-2.0
3+
4+
package knownvalue
5+
6+
import (
7+
"encoding/json"
8+
"fmt"
9+
"strconv"
10+
)
11+
12+
var _ Check = float32Func{}
13+
14+
type float32Func struct {
15+
checkFunc func(v float32) error
16+
}
17+
18+
// CheckValue determines whether the passed value is of type float32, and
19+
// returns no error from the provided check function
20+
func (v float32Func) CheckValue(other any) error {
21+
jsonNum, ok := other.(json.Number)
22+
23+
if !ok {
24+
return fmt.Errorf("expected json.Number value for Float32Func check, got: %T", other)
25+
}
26+
27+
otherVal, err := strconv.ParseFloat(string(jsonNum), 32)
28+
if err != nil {
29+
return fmt.Errorf("expected json.Number to be parseable as float32 value for Float32Func check: %s", err)
30+
}
31+
32+
return v.checkFunc(float32(otherVal))
33+
}
34+
35+
// String returns the float32 representation of the value.
36+
func (v float32Func) String() string {
37+
// Validation is up the the implementer of the function, so there are no
38+
// float32 literal or regex comparers to print here
39+
return "Float32Func"
40+
}
41+
42+
// Float32Func returns a Check for passing the float32 value in state
43+
// to the provided check function
44+
func Float32Func(fn func(v float32) error) float32Func {
45+
return float32Func{
46+
checkFunc: fn,
47+
}
48+
}

knownvalue/float32_func_test.go

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
// Copyright (c) HashiCorp, Inc.
2+
// SPDX-License-Identifier: MPL-2.0
3+
4+
package knownvalue_test
5+
6+
import (
7+
"encoding/json"
8+
"fmt"
9+
"testing"
10+
11+
"github.com/google/go-cmp/cmp"
12+
13+
"github.com/hashicorp/terraform-plugin-testing/knownvalue"
14+
)
15+
16+
func TestFloat32Func_CheckValue(t *testing.T) {
17+
t.Parallel()
18+
19+
testCases := map[string]struct {
20+
self knownvalue.Check
21+
other any
22+
expectedError error
23+
}{
24+
"nil": {
25+
self: knownvalue.Float32Func(func(float32) error { return nil }),
26+
expectedError: fmt.Errorf("expected json.Number value for Float32Func check, got: <nil>"),
27+
},
28+
"wrong-type": {
29+
self: knownvalue.Float32Func(func(float32) error { return nil }),
30+
other: "wrongtype",
31+
expectedError: fmt.Errorf("expected json.Number value for Float32Func check, got: string"),
32+
},
33+
"failure": {
34+
self: knownvalue.Float32Func(func(f float32) error {
35+
if f != 1.1 {
36+
return fmt.Errorf("%f was not 1.1", f)
37+
}
38+
return nil
39+
}),
40+
other: json.Number("1.2"),
41+
expectedError: fmt.Errorf("%f was not 1.1", 1.2),
42+
},
43+
"success": {
44+
self: knownvalue.Float32Func(func(f float32) error {
45+
if f != 1.1 {
46+
return fmt.Errorf("%f was not 1.1", f)
47+
}
48+
return nil
49+
}),
50+
other: json.Number("1.1"),
51+
},
52+
}
53+
54+
for name, testCase := range testCases {
55+
name, testCase := name, testCase
56+
57+
t.Run(name, func(t *testing.T) {
58+
t.Parallel()
59+
60+
got := testCase.self.CheckValue(testCase.other)
61+
62+
if diff := cmp.Diff(got, testCase.expectedError, equateErrorMessage); diff != "" {
63+
t.Errorf("unexpected difference: %s", diff)
64+
}
65+
})
66+
}
67+
}
68+
69+
func TestFloat32Func_String(t *testing.T) {
70+
t.Parallel()
71+
72+
got := knownvalue.Float32Func(func(float32) error { return nil }).String()
73+
74+
if diff := cmp.Diff(got, "Float32Func"); diff != "" {
75+
t.Errorf("unexpected difference: %s", diff)
76+
}
77+
}

knownvalue/float64_func.go

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
// Copyright (c) HashiCorp, Inc.
2+
// SPDX-License-Identifier: MPL-2.0
3+
4+
package knownvalue
5+
6+
import (
7+
"encoding/json"
8+
"fmt"
9+
"strconv"
10+
)
11+
12+
var _ Check = float64Func{}
13+
14+
type float64Func struct {
15+
checkFunc func(v float64) error
16+
}
17+
18+
// CheckValue determines whether the passed value is of type float64, and
19+
// returns no error from the provided check function
20+
func (v float64Func) CheckValue(other any) error {
21+
jsonNum, ok := other.(json.Number)
22+
23+
if !ok {
24+
return fmt.Errorf("expected json.Number value for Float64Func check, got: %T", other)
25+
}
26+
27+
otherVal, err := strconv.ParseFloat(string(jsonNum), 64)
28+
if err != nil {
29+
return fmt.Errorf("expected json.Number to be parseable as float64 value for Float64Func check: %s", err)
30+
}
31+
32+
return v.checkFunc(otherVal)
33+
}
34+
35+
// String returns the float64 representation of the value.
36+
func (v float64Func) String() string {
37+
// Validation is up the the implementer of the function, so there are no
38+
// float64 literal or regex comparers to print here
39+
return "Float64Func"
40+
}
41+
42+
// Float64Func returns a Check for passing the float64 value in state
43+
// to the provided check function
44+
func Float64Func(fn func(v float64) error) float64Func {
45+
return float64Func{
46+
checkFunc: fn,
47+
}
48+
}

knownvalue/float64_func_test.go

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
// Copyright (c) HashiCorp, Inc.
2+
// SPDX-License-Identifier: MPL-2.0
3+
4+
package knownvalue_test
5+
6+
import (
7+
"encoding/json"
8+
"fmt"
9+
"testing"
10+
11+
"github.com/google/go-cmp/cmp"
12+
13+
"github.com/hashicorp/terraform-plugin-testing/knownvalue"
14+
)
15+
16+
func TestFloat64Func_CheckValue(t *testing.T) {
17+
t.Parallel()
18+
19+
testCases := map[string]struct {
20+
self knownvalue.Check
21+
other any
22+
expectedError error
23+
}{
24+
"nil": {
25+
self: knownvalue.Float64Func(func(float64) error { return nil }),
26+
expectedError: fmt.Errorf("expected json.Number value for Float64Func check, got: <nil>"),
27+
},
28+
"wrong-type": {
29+
self: knownvalue.Float64Func(func(float64) error { return nil }),
30+
other: "wrongtype",
31+
expectedError: fmt.Errorf("expected json.Number value for Float64Func check, got: string"),
32+
},
33+
"failure": {
34+
self: knownvalue.Float64Func(func(f float64) error {
35+
if f != 1.1 {
36+
return fmt.Errorf("%f was not 1.1", f)
37+
}
38+
return nil
39+
}),
40+
other: json.Number("1.2"),
41+
expectedError: fmt.Errorf("%f was not 1.1", 1.2),
42+
},
43+
"success": {
44+
self: knownvalue.Float64Func(func(f float64) error {
45+
if f != 1.1 {
46+
return fmt.Errorf("%f was not 1.1", f)
47+
}
48+
return nil
49+
}),
50+
other: json.Number("1.1"),
51+
},
52+
}
53+
54+
for name, testCase := range testCases {
55+
name, testCase := name, testCase
56+
57+
t.Run(name, func(t *testing.T) {
58+
t.Parallel()
59+
60+
got := testCase.self.CheckValue(testCase.other)
61+
62+
if diff := cmp.Diff(got, testCase.expectedError, equateErrorMessage); diff != "" {
63+
t.Errorf("unexpected difference: %s", diff)
64+
}
65+
})
66+
}
67+
}
68+
69+
func TestFloat64Func_String(t *testing.T) {
70+
t.Parallel()
71+
72+
got := knownvalue.Float64Func(func(float64) error { return nil }).String()
73+
74+
if diff := cmp.Diff(got, "Float64Func"); diff != "" {
75+
t.Errorf("unexpected difference: %s", diff)
76+
}
77+
}

0 commit comments

Comments
 (0)