Skip to content

Commit 8a4743a

Browse files
committed
fixup! refactor: Handle overrides in failure domain validation
1 parent ead35e8 commit 8a4743a

File tree

4 files changed

+365
-30
lines changed

4 files changed

+365
-30
lines changed
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
// Copyright 2023 Nutanix. All rights reserved.
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
package variables
5+
6+
import (
7+
"encoding/json"
8+
"fmt"
9+
"maps"
10+
11+
"dario.cat/mergo"
12+
apiextensionsv1 "k8s.io/apiextensions-apiserver/pkg/apis/apiextensions/v1"
13+
)
14+
15+
// MergeVariableOverridesWithGlobal merges the provided variable overrides with the global variables.
16+
// It performs a deep merge, ensuring that if a variable exists in both maps, the value from the overrides is used.
17+
func MergeVariableOverridesWithGlobal(
18+
overrideVars, globalVars map[string]apiextensionsv1.JSON,
19+
) (map[string]apiextensionsv1.JSON, error) {
20+
mergedVars := maps.Clone(overrideVars)
21+
22+
for k, v := range globalVars {
23+
// If the value of v is nil, skip it.
24+
if v.Raw == nil {
25+
continue
26+
}
27+
28+
existingValue, exists := mergedVars[k]
29+
30+
// If the variable does not exist in the mergedVars or the value is nil, add it and continue.
31+
if !exists || existingValue.Raw == nil {
32+
mergedVars[k] = v
33+
continue
34+
}
35+
36+
// Wrap the value in a temporary key to ensure we can unmarshal to a map.
37+
// This is necessary because the values could be scalars.
38+
tempValJSON := fmt.Sprintf(`{"value": %s}`, string(existingValue.Raw))
39+
tempGlobalValJSON := fmt.Sprintf(`{"value": %s}`, string(v.Raw))
40+
41+
// Unmarshal the existing value and the global value into maps.
42+
var val, globalVal map[string]interface{}
43+
if err := json.Unmarshal([]byte(tempValJSON), &val); err != nil {
44+
return nil, fmt.Errorf("failed to unmarshal existing value for key %q: %w", k, err)
45+
}
46+
47+
if err := json.Unmarshal([]byte(tempGlobalValJSON), &globalVal); err != nil {
48+
return nil, fmt.Errorf("failed to unmarshal global value for key %q: %w", k, err)
49+
}
50+
51+
// Now use mergo to perform a deep merge of the values, retaining the values in `val` if present.
52+
if err := mergo.Merge(&val, globalVal); err != nil {
53+
return nil, fmt.Errorf("failed to merge values for key %q: %w", k, err)
54+
}
55+
56+
// Marshal the merged value back to JSON.
57+
mergedVal, err := json.Marshal(val["value"])
58+
if err != nil {
59+
return nil, fmt.Errorf("failed to marshal merged value for key %q: %w", k, err)
60+
}
61+
62+
mergedVars[k] = apiextensionsv1.JSON{Raw: mergedVal}
63+
}
64+
65+
return mergedVars, nil
66+
}
Lines changed: 207 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,207 @@
1+
// Copyright 2023 Nutanix. All rights reserved.
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
package variables
5+
6+
import (
7+
"encoding/json"
8+
"testing"
9+
10+
"github.com/onsi/gomega"
11+
apiextensionsv1 "k8s.io/apiextensions-apiserver/pkg/apis/apiextensions/v1"
12+
)
13+
14+
func TestMergeVariableOverridesWithGlobal(t *testing.T) {
15+
t.Parallel()
16+
17+
type args struct {
18+
vars map[string]apiextensionsv1.JSON
19+
globalVars map[string]apiextensionsv1.JSON
20+
}
21+
tests := []struct {
22+
name string
23+
args args
24+
want map[string]apiextensionsv1.JSON
25+
wantErr bool
26+
errString string
27+
}{
28+
{
29+
name: "no overlap, globalVars added",
30+
args: args{
31+
vars: map[string]apiextensionsv1.JSON{
32+
"a": {Raw: []byte(`1`)},
33+
},
34+
globalVars: map[string]apiextensionsv1.JSON{
35+
"b": {Raw: []byte(`2`)},
36+
},
37+
},
38+
want: map[string]apiextensionsv1.JSON{
39+
"a": {Raw: []byte(`1`)},
40+
"b": {Raw: []byte(`2`)},
41+
},
42+
},
43+
{
44+
name: "globalVars value is nil, skipped",
45+
args: args{
46+
vars: map[string]apiextensionsv1.JSON{
47+
"a": {Raw: []byte(`1`)},
48+
},
49+
globalVars: map[string]apiextensionsv1.JSON{
50+
"b": {Raw: nil},
51+
},
52+
},
53+
want: map[string]apiextensionsv1.JSON{
54+
"a": {Raw: []byte(`1`)},
55+
},
56+
},
57+
{
58+
name: "existing value is nil, globalVars value used",
59+
args: args{
60+
vars: map[string]apiextensionsv1.JSON{
61+
"a": {Raw: nil},
62+
},
63+
globalVars: map[string]apiextensionsv1.JSON{
64+
"a": {Raw: []byte(`2`)},
65+
},
66+
},
67+
want: map[string]apiextensionsv1.JSON{
68+
"a": {Raw: []byte(`2`)},
69+
},
70+
},
71+
{
72+
name: "both values are scalars, globalVars ignored",
73+
args: args{
74+
vars: map[string]apiextensionsv1.JSON{
75+
"a": {Raw: []byte(`1`)},
76+
},
77+
globalVars: map[string]apiextensionsv1.JSON{
78+
"a": {Raw: []byte(`2`)},
79+
},
80+
},
81+
want: map[string]apiextensionsv1.JSON{
82+
"a": {Raw: []byte(`1`)},
83+
},
84+
},
85+
{
86+
name: "both values are objects, merged",
87+
args: args{
88+
vars: map[string]apiextensionsv1.JSON{
89+
"a": {Raw: []byte(`{"x":1,"y":2}`)},
90+
},
91+
globalVars: map[string]apiextensionsv1.JSON{
92+
"a": {Raw: []byte(`{"y":3,"z":4}`)},
93+
},
94+
},
95+
want: map[string]apiextensionsv1.JSON{
96+
"a": {Raw: []byte(`{"x":1,"y":2,"z":4}`)},
97+
},
98+
},
99+
{
100+
name: "both values are objects with nested objects, merged",
101+
args: args{
102+
vars: map[string]apiextensionsv1.JSON{
103+
"a": {Raw: []byte(`{"x":1,"y":{"a": 2,"b":{"c": 3}}}`)},
104+
},
105+
globalVars: map[string]apiextensionsv1.JSON{
106+
"a": {Raw: []byte(`{"y":{"a": 2,"b":{"c": 5, "d": 6}},"z":4}`)},
107+
},
108+
},
109+
want: map[string]apiextensionsv1.JSON{
110+
"a": {Raw: []byte(`{"x":1,"y":{"a": 2,"b":{"c": 3, "d": 6}},"z":4}`)},
111+
},
112+
},
113+
{
114+
name: "both values are objects with nested objects with vars having nil object explicitly, merged",
115+
args: args{
116+
vars: map[string]apiextensionsv1.JSON{
117+
"a": {Raw: []byte(`{"x":1,"y":{"a": 2,"b": null}}`)},
118+
},
119+
globalVars: map[string]apiextensionsv1.JSON{
120+
"a": {Raw: []byte(`{"y":{"a": 2,"b":{"c": 5, "d": 6}},"z":4}`)},
121+
},
122+
},
123+
want: map[string]apiextensionsv1.JSON{
124+
"a": {Raw: []byte(`{"x":1,"y":{"a": 2,"b":{"c": 5, "d": 6}},"z":4}`)},
125+
},
126+
},
127+
{
128+
name: "globalVars is scalar, vars is object, keep object",
129+
args: args{
130+
vars: map[string]apiextensionsv1.JSON{
131+
"a": {Raw: []byte(`{"x":1}`)},
132+
},
133+
globalVars: map[string]apiextensionsv1.JSON{
134+
"a": {Raw: []byte(`2`)},
135+
},
136+
},
137+
want: map[string]apiextensionsv1.JSON{
138+
"a": {Raw: []byte(`{"x":1}`)},
139+
},
140+
},
141+
{
142+
name: "vars is scalar, globalVars is object, keep scalar",
143+
args: args{
144+
vars: map[string]apiextensionsv1.JSON{
145+
"a": {Raw: []byte(`2`)},
146+
},
147+
globalVars: map[string]apiextensionsv1.JSON{
148+
"a": {Raw: []byte(`{"x":1}`)},
149+
},
150+
},
151+
want: map[string]apiextensionsv1.JSON{
152+
"a": {Raw: []byte(`2`)},
153+
},
154+
},
155+
{
156+
name: "invalid JSON in vars",
157+
args: args{
158+
vars: map[string]apiextensionsv1.JSON{
159+
"a": {Raw: []byte(`{invalid}`)},
160+
},
161+
globalVars: map[string]apiextensionsv1.JSON{
162+
"a": {Raw: []byte(`{"x":1}`)},
163+
},
164+
},
165+
wantErr: true,
166+
errString: "failed to unmarshal existing value for key \"a\"",
167+
},
168+
{
169+
name: "invalid JSON in globalVars",
170+
args: args{
171+
vars: map[string]apiextensionsv1.JSON{
172+
"a": {Raw: []byte(`{"x":1}`)},
173+
},
174+
globalVars: map[string]apiextensionsv1.JSON{
175+
"a": {Raw: []byte(`{invalid}`)},
176+
},
177+
},
178+
wantErr: true,
179+
errString: "failed to unmarshal global value for key \"a\"",
180+
},
181+
}
182+
183+
for _, tt := range tests {
184+
t.Run(tt.name, func(t *testing.T) {
185+
t.Parallel()
186+
g := gomega.NewWithT(t)
187+
got, err := MergeVariableOverridesWithGlobal(tt.args.vars, tt.args.globalVars)
188+
if tt.wantErr {
189+
g.Expect(err).To(gomega.HaveOccurred())
190+
g.Expect(err.Error()).To(gomega.ContainSubstring(tt.errString))
191+
return
192+
}
193+
g.Expect(err).ToNot(gomega.HaveOccurred())
194+
// Compare JSON values
195+
for k, wantVal := range tt.want {
196+
gotVal, ok := got[k]
197+
g.Expect(ok).To(gomega.BeTrue(), "missing key %q", k)
198+
var wantObj, gotObj interface{}
199+
_ = json.Unmarshal(wantVal.Raw, &wantObj)
200+
_ = json.Unmarshal(gotVal.Raw, &gotObj)
201+
g.Expect(gotObj).To(gomega.Equal(wantObj), "key %q", k)
202+
}
203+
// Check for unexpected keys
204+
g.Expect(len(got)).To(gomega.Equal(len(tt.want)))
205+
})
206+
}
207+
}

pkg/webhook/cluster/nutanix_validator.go

Lines changed: 54 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,15 @@ import (
1111
"net/netip"
1212

1313
v1 "k8s.io/api/admission/v1"
14+
apiextensionsv1 "k8s.io/apiextensions-apiserver/pkg/apis/apiextensions/v1"
1415
"k8s.io/apimachinery/pkg/util/validation/field"
1516
clusterv1 "sigs.k8s.io/cluster-api/api/v1beta1"
1617
ctrlclient "sigs.k8s.io/controller-runtime/pkg/client"
1718
"sigs.k8s.io/controller-runtime/pkg/webhook/admission"
1819

1920
"github.com/nutanix-cloud-native/cluster-api-runtime-extensions-nutanix/api/v1alpha1"
2021
"github.com/nutanix-cloud-native/cluster-api-runtime-extensions-nutanix/api/variables"
22+
commonvariables "github.com/nutanix-cloud-native/cluster-api-runtime-extensions-nutanix/common/pkg/capi/clustertopology/variables"
2123
"github.com/nutanix-cloud-native/cluster-api-runtime-extensions-nutanix/common/pkg/capi/utils"
2224
"github.com/nutanix-cloud-native/cluster-api-runtime-extensions-nutanix/pkg/helpers"
2325
)
@@ -220,41 +222,69 @@ func validateWorkerFailureDomainConfig(
220222
"workerConfig",
221223
)
222224

223-
// Get the machineDetails from cluster variable "workerConfig" if it is configured
224-
defaultWorkerConfig, err := variables.UnmarshalWorkerConfigVariable(cluster.Spec.Topology.Variables)
225-
if err != nil {
226-
fldErrs = append(fldErrs, field.InternalError(workerConfigVarPath,
227-
fmt.Errorf("failed to unmarshal cluster topology variable %q: %w", v1alpha1.WorkerConfigVariableName, err)))
228-
}
225+
// Merge the global cluster variables with the worker config overrides.
226+
mdVariables := commonvariables.ClusterVariablesToVariablesMap(cluster.Spec.Topology.Variables)
229227

230228
if cluster.Spec.Topology.Workers != nil {
231229
for i := range cluster.Spec.Topology.Workers.MachineDeployments {
232230
md := cluster.Spec.Topology.Workers.MachineDeployments[i]
233231
hasFailureDomain := md.FailureDomain != nil && *md.FailureDomain != ""
232+
wcfgPath := workerConfigVarPath
234233

235-
// Get the machineDetails from the overrides variable "workerConfig" if it is configured,
236-
// otherwise use the defaultWorkerConfig if it is configured.
237-
var workerConfig *variables.WorkerNodeConfigSpec
234+
// Get the md variable overrides.
235+
var mdVariableOverrides map[string]apiextensionsv1.JSON
238236
if md.Variables != nil && len(md.Variables.Overrides) > 0 {
239-
workerConfig, err = variables.UnmarshalWorkerConfigVariable(md.Variables.Overrides)
240-
if err != nil {
241-
fldErrs = append(fldErrs, field.InternalError(
242-
workerConfigMDVarOverridePath,
243-
fmt.Errorf(
244-
"failed to unmarshal worker overrides variable %q: %w",
245-
v1alpha1.WorkerConfigVariableName,
246-
err,
247-
),
248-
))
237+
wcfgPath = workerConfigMDVarOverridePath
238+
mdVariableOverrides = commonvariables.ClusterVariablesToVariablesMap(md.Variables.Overrides)
239+
240+
// If mdVariables is nil, initialize it with mdVariableOverrides, otherwise merge global and
241+
// variable overrides.
242+
if mdVariables == nil {
243+
mdVariables = mdVariableOverrides
244+
} else {
245+
// Merge global and variable overrides if global variables are present.
246+
mergedVariables, err := commonvariables.MergeVariableOverridesWithGlobal(
247+
mdVariableOverrides,
248+
mdVariables,
249+
)
250+
if err != nil {
251+
fldErrs = append(fldErrs, field.InternalError(
252+
workerConfigMDVarOverridePath,
253+
fmt.Errorf(
254+
"failed to merge global and worker variable overrides for MachineDeployment %q: %w",
255+
md.Name,
256+
err,
257+
),
258+
))
259+
}
260+
261+
mdVariables = mergedVariables
249262
}
250263
}
251264

252-
wcfgPath := workerConfigMDVarOverridePath
253-
if workerConfig == nil {
254-
workerConfig = defaultWorkerConfig
255-
wcfgPath = workerConfigVarPath
265+
workerConfigVarJSON, workerConfigPresent := mdVariables[v1alpha1.WorkerConfigVariableName]
266+
if !workerConfigPresent {
267+
continue
268+
}
269+
270+
workerConfigClusterVar := clusterv1.ClusterVariable{
271+
Name: v1alpha1.WorkerConfigVariableName,
272+
Value: *workerConfigVarJSON.DeepCopy(),
273+
}
274+
275+
workerConfig := variables.WorkerNodeConfigSpec{}
276+
if err := variables.UnmarshalClusterVariable(&workerConfigClusterVar, &workerConfig); err != nil {
277+
fldErrs = append(fldErrs, field.InternalError(
278+
workerConfigMDVarOverridePath,
279+
fmt.Errorf(
280+
"failed to unmarshal worker overrides variable %q: %w",
281+
v1alpha1.WorkerConfigVariableName,
282+
err,
283+
),
284+
))
256285
}
257-
if workerConfig == nil || workerConfig.Nutanix == nil {
286+
287+
if workerConfig.Nutanix == nil {
258288
continue
259289
}
260290

0 commit comments

Comments
 (0)