Skip to content

Commit 89a1666

Browse files
Fix assertJSONObjectsSimilar
1 parent 021c745 commit 89a1666

File tree

6 files changed

+116
-135
lines changed

6 files changed

+116
-135
lines changed

helpers_iterator.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@ import (
55
"slices"
66
)
77

8-
// Map returns a new iterator of the values in the given iterator transformed using the given transform function.
9-
func Map[I, O any](values iter.Seq[I], transform func(I) O) iter.Seq[O] {
8+
// mapIter returns a new iterator of the values in the given iterator transformed using the given transform function.
9+
func mapIter[I, O any](values iter.Seq[I], transform func(I) O) iter.Seq[O] {
1010
return func(yield func(O) bool) {
1111
for value := range values {
1212
if !yield(transform(value)) {
@@ -16,7 +16,7 @@ func Map[I, O any](values iter.Seq[I], transform func(I) O) iter.Seq[O] {
1616
}
1717
}
1818

19-
// MapSlice returns a new slice of the values in the given slice transformed using the given transform function.
20-
func MapSlice[I, O any](values []I, transform func(I) O) []O {
21-
return slices.Collect(Map(slices.Values(values), transform))
19+
// mapSlice returns a new slice of the values in the given slice transformed using the given transform function.
20+
func mapSlice[I, O any](values []I, transform func(I) O) []O {
21+
return slices.Collect(mapIter(slices.Values(values), transform))
2222
}

instance_config_interfaces.go

Lines changed: 27 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -127,36 +127,34 @@ func (i InstanceConfigInterface) GetCreateOptions() InstanceConfigInterfaceCreat
127127
opts.IPRanges = i.IPRanges
128128
}
129129

130-
if i.Purpose == InterfacePurposeVPC {
131-
if i.IPv4 != nil {
132-
opts.IPv4 = &VPCIPv4{
133-
VPC: i.IPv4.VPC,
134-
NAT1To1: i.IPv4.NAT1To1,
135-
}
130+
if i.IPv4 != nil {
131+
opts.IPv4 = &VPCIPv4{
132+
VPC: i.IPv4.VPC,
133+
NAT1To1: i.IPv4.NAT1To1,
136134
}
135+
}
137136

138-
if i.IPv6 != nil {
139-
ipv6 := *i.IPv6
137+
if i.IPv6 != nil {
138+
ipv6 := *i.IPv6
140139

141-
opts.IPv6 = &InstanceConfigInterfaceCreateOptionsIPv6{
142-
SLAAC: MapSlice(
143-
ipv6.SLAAC,
144-
func(i InstanceConfigInterfaceIPv6SLAAC) InstanceConfigInterfaceCreateOptionsIPv6SLAAC {
145-
return InstanceConfigInterfaceCreateOptionsIPv6SLAAC{
146-
Range: i.Range,
147-
}
148-
},
149-
),
150-
Ranges: MapSlice(
151-
ipv6.Ranges,
152-
func(i InstanceConfigInterfaceIPv6Range) InstanceConfigInterfaceCreateOptionsIPv6Range {
153-
return InstanceConfigInterfaceCreateOptionsIPv6Range{
154-
Range: copyValue(&i.Range),
155-
}
156-
},
157-
),
158-
IsPublic: copyValue(&ipv6.IsPublic),
159-
}
140+
opts.IPv6 = &InstanceConfigInterfaceCreateOptionsIPv6{
141+
SLAAC: mapSlice(
142+
ipv6.SLAAC,
143+
func(i InstanceConfigInterfaceIPv6SLAAC) InstanceConfigInterfaceCreateOptionsIPv6SLAAC {
144+
return InstanceConfigInterfaceCreateOptionsIPv6SLAAC{
145+
Range: i.Range,
146+
}
147+
},
148+
),
149+
Ranges: mapSlice(
150+
ipv6.Ranges,
151+
func(i InstanceConfigInterfaceIPv6Range) InstanceConfigInterfaceCreateOptionsIPv6Range {
152+
return InstanceConfigInterfaceCreateOptionsIPv6Range{
153+
Range: copyValue(&i.Range),
154+
}
155+
},
156+
),
157+
IsPublic: copyValue(&ipv6.IsPublic),
160158
}
161159
}
162160

@@ -181,7 +179,7 @@ func (i InstanceConfigInterface) GetUpdateOptions() InstanceConfigInterfaceUpdat
181179
if i.IPv6 != nil {
182180
ipv6 := *i.IPv6
183181

184-
newSLAAC := MapSlice(
182+
newSLAAC := mapSlice(
185183
ipv6.SLAAC,
186184
func(i InstanceConfigInterfaceIPv6SLAAC) InstanceConfigInterfaceUpdateOptionsIPv6SLAAC {
187185
return InstanceConfigInterfaceUpdateOptionsIPv6SLAAC{
@@ -190,7 +188,7 @@ func (i InstanceConfigInterface) GetUpdateOptions() InstanceConfigInterfaceUpdat
190188
},
191189
)
192190

193-
newRanges := MapSlice(
191+
newRanges := mapSlice(
194192
ipv6.Ranges,
195193
func(i InstanceConfigInterfaceIPv6Range) InstanceConfigInterfaceUpdateOptionsIPv6Range {
196194
return InstanceConfigInterfaceUpdateOptionsIPv6Range{

instance_configs.go

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -118,23 +118,27 @@ func (i *InstanceConfig) UnmarshalJSON(b []byte) error {
118118

119119
// GetCreateOptions converts a InstanceConfig to InstanceConfigCreateOptions for use in CreateInstanceConfig
120120
func (i InstanceConfig) GetCreateOptions() InstanceConfigCreateOptions {
121-
initrd := 0
122-
if i.InitRD != nil {
123-
initrd = *i.InitRD
124-
}
125-
return InstanceConfigCreateOptions{
121+
result := InstanceConfigCreateOptions{
126122
Label: i.Label,
127123
Comments: i.Comments,
128-
Devices: *i.Devices,
129124
Helpers: i.Helpers,
130125
Interfaces: getInstanceConfigInterfacesCreateOptionsList(i.Interfaces),
131126
MemoryLimit: i.MemoryLimit,
132127
Kernel: i.Kernel,
133-
InitRD: initrd,
134128
RootDevice: copyString(&i.RootDevice),
135129
RunLevel: i.RunLevel,
136130
VirtMode: i.VirtMode,
137131
}
132+
133+
if i.InitRD != nil {
134+
result.InitRD = *i.InitRD
135+
}
136+
137+
if i.Devices != nil {
138+
result.Devices = *i.Devices
139+
}
140+
141+
return result
138142
}
139143

140144
// GetUpdateOptions converts a InstanceConfig to InstanceConfigUpdateOptions for use in UpdateInstanceConfig

test/unit/util.go

Lines changed: 71 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
package unit
22

33
import (
4+
"encoding/json"
45
"reflect"
56
"slices"
67
"strconv"
78
"strings"
89
"testing"
910

10-
"github.com/stretchr/testify/assert"
1111
"github.com/stretchr/testify/require"
1212
)
1313

@@ -17,127 +17,106 @@ import (
1717
// This is primarily used to ensure that the GetCreateOptions() and GetUpdateOptions()
1818
// functions are implemented correctly.
1919
func assertJSONObjectsSimilar[TA, TB any](t testing.TB, a TA, b TB) {
20-
assertJSONObjectsSimilarInner(t, []string{}, a, b)
20+
// Encoding and decoding JSON here is hacky, but it
21+
// lets us avoid some ugly type reflection pointer logic
22+
aJSON, err := json.Marshal(a)
23+
require.NoError(t, err)
24+
25+
bJSON, err := json.Marshal(b)
26+
require.NoError(t, err)
27+
28+
var aParsed, bParsed map[string]any
29+
30+
require.NoError(t, json.Unmarshal(aJSON, &aParsed))
31+
require.NoError(t, json.Unmarshal(bJSON, &bParsed))
32+
33+
assertJSONObjectsSimilarInner(t, []string{}, aParsed, bParsed)
2134
}
2235

23-
func assertJSONObjectsSimilarInner[TA, TB any](t testing.TB, path []string, a TA, b TB) {
24-
aValue := derefValueRecursive(reflect.ValueOf(a))
25-
bValue := derefValueRecursive(reflect.ValueOf(b))
36+
func assertJSONObjectsSimilarInner(t testing.TB, path []string, a, b any) {
37+
a = normalizeEmptyValues(a)
38+
b = normalizeEmptyValues(b)
39+
40+
aValue := reflect.ValueOf(a)
41+
bValue := reflect.ValueOf(b)
2642

27-
aFields := aggregateJSONFields(reflect.ValueOf(a))
28-
bFields := aggregateJSONFields(reflect.ValueOf(b))
43+
aKind := aValue.Kind()
44+
bKind := bValue.Kind()
2945

3046
require.Equalf(
3147
t,
32-
aValue.Kind(),
33-
bValue.Kind(),
34-
"%s kind mismatch: %s != %s",
35-
path,
36-
aValue.Kind(),
37-
bValue.Kind(),
48+
aKind,
49+
bKind,
50+
"%s type mismatch: %s != %s",
51+
strings.Join(path, "."),
52+
aKind,
53+
bKind,
3854
)
3955

4056
switch aValue.Kind() {
57+
case reflect.Map:
58+
for _, key := range aValue.MapKeys() {
59+
aFieldValue := aValue.MapIndex(key)
60+
bFieldValue := bValue.MapIndex(key)
61+
62+
if !bFieldValue.IsValid() {
63+
// This key is not shared so we can ignore it
64+
continue
65+
}
66+
67+
assertJSONObjectsSimilarInner(
68+
t,
69+
slices.Concat(path, []string{key.String()}),
70+
aFieldValue.Interface(),
71+
bFieldValue.Interface(),
72+
)
73+
}
4174
case reflect.Slice:
42-
assert.Equalf(
75+
require.Equalf(
4376
t,
4477
aValue.Len(),
4578
bValue.Len(),
4679
"%s slice length mismatch: %d != %d",
47-
path,
80+
strings.Join(path, "."),
4881
aValue.Len(),
4982
bValue.Len(),
5083
)
5184

5285
for index := range aValue.Len() {
53-
assertJSONObjectsSimilarInner(
54-
t,
55-
slices.Concat(path, []string{strconv.Itoa(index)}),
56-
aValue.Index(index),
57-
bValue.Index(index),
58-
)
59-
}
60-
case reflect.Map:
61-
aKeys := aValue.MapKeys()
62-
bKeys := bValue.MapKeys()
63-
64-
assert.Equalf(
65-
t,
66-
aKeys,
67-
bKeys,
68-
"%s map keys mismatch: %v != %v",
69-
path,
70-
aKeys,
71-
bKeys,
72-
)
86+
aFieldValue := aValue.Index(index)
87+
bFieldValue := bValue.Index(index)
7388

74-
for _, key := range aKeys {
7589
assertJSONObjectsSimilarInner(
7690
t,
77-
slices.Concat(path, []string{key.String()}),
78-
aValue.MapIndex(key),
79-
bValue.MapIndex(key),
91+
slices.Concat(path, []string{strconv.Itoa(index)}),
92+
aFieldValue.Interface(),
93+
bFieldValue.Interface(),
8094
)
8195
}
82-
case reflect.Struct:
83-
for key, aFieldValue := range aFields {
84-
bFieldValue, ok := bFields[key]
85-
if !ok {
86-
// This key isn't shared, nothing to do here
87-
continue
88-
}
89-
90-
assertJSONObjectsSimilarInner(t, slices.Concat(path, []string{key}), aFieldValue, bFieldValue)
91-
}
9296
default:
93-
assert.Equal(
97+
require.Equal(
9498
t,
95-
aValue.Interface(),
96-
bValue.Interface(),
97-
"%s value mismatch: %s != %s",
98-
path,
99-
aValue.Interface(),
100-
bValue.Interface(),
99+
a,
100+
b,
101+
"%s value mismatch: %v != %v",
102+
strings.Join(path, "."),
103+
a,
104+
b,
101105
)
102106
}
103107
}
104108

105-
func aggregateJSONFields(v reflect.Value) map[string]reflect.Value {
106-
vType := derefTypeRecursive(v.Type())
107-
108-
result := make(map[string]reflect.Value, vType.NumField())
109-
110-
for fieldNum := range vType.NumField() {
111-
field := vType.Field(fieldNum)
112-
113-
jsonTag, jsonTagOk := field.Tag.Lookup("json")
114-
if !jsonTagOk {
115-
// No JSON tag is defined, nothing to do here
116-
continue
117-
}
118-
119-
if jsonTag == "-" {
120-
continue
121-
}
122-
123-
jsonTagKey := strings.Split(jsonTag, ",")[0]
124-
result[jsonTagKey] = derefValueRecursive(derefValueRecursive(v).FieldByName(field.Name))
125-
}
126-
127-
return result
128-
}
129-
130-
func derefTypeRecursive(v reflect.Type) reflect.Type {
131-
if v.Kind() == reflect.Ptr {
132-
return derefTypeRecursive(v.Elem())
133-
}
134-
135-
return v
136-
}
137-
138-
func derefValueRecursive(v reflect.Value) reflect.Value {
139-
if v.Kind() == reflect.Ptr {
140-
return derefValueRecursive(v.Elem())
109+
// normalizeEmptyValues normalizes the given value for use in JSON object diffing,
110+
// primarily replacing any map, slice, or array values with nil.
111+
//
112+
// This is necessary because an empty length-having type is functionally equivalent
113+
// to a nil value when using GetCreateOptions(...) and GetUpdateOptions(...).
114+
func normalizeEmptyValues(v any) any {
115+
vValue := reflect.ValueOf(v)
116+
vKind := vValue.Kind()
117+
118+
if (vKind == reflect.Map || vKind == reflect.Slice || vKind == reflect.Array) && vValue.Len() < 1 {
119+
return nil
141120
}
142121

143122
return v

vpc.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ func (v VPC) GetCreateOptions() VPCCreateOptions {
5555
Description: v.Description,
5656
Region: v.Region,
5757
Subnets: subnetCreations,
58-
IPv6: MapSlice(v.IPv6, func(i VPCIPv6Range) VPCCreateOptionsIPv6 {
58+
IPv6: mapSlice(v.IPv6, func(i VPCIPv6Range) VPCCreateOptionsIPv6 {
5959
return VPCCreateOptionsIPv6{
6060
Range: copyValue(&i.Range),
6161
}

vpc_subnet.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ func (v VPCSubnet) GetCreateOptions() VPCSubnetCreateOptions {
7070
return VPCSubnetCreateOptions{
7171
Label: v.Label,
7272
IPv4: v.IPv4,
73-
IPv6: MapSlice(v.IPv6, func(i VPCIPv6Range) VPCSubnetCreateOptionsIPv6 {
73+
IPv6: mapSlice(v.IPv6, func(i VPCIPv6Range) VPCSubnetCreateOptionsIPv6 {
7474
return VPCSubnetCreateOptionsIPv6{
7575
Range: copyValue(&i.Range),
7676
}

0 commit comments

Comments
 (0)