Skip to content

Commit c476d34

Browse files
Merge pull request #1967 from karthikvetrivel/fix/upgrade-controller-label-predicate
Upgrade Controller: filter node watch to only upgrade state label changes
2 parents e0ba4ab + 183300b commit c476d34

File tree

2 files changed

+127
-3
lines changed

2 files changed

+127
-3
lines changed

controllers/state_manager_test.go

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,3 +69,120 @@ func TestGetRuntimeString(t *testing.T) {
6969
})
7070
}
7171
}
72+
73+
func TestIsValidWorkloadConfig(t *testing.T) {
74+
tests := []struct {
75+
config string
76+
want bool
77+
}{
78+
{gpuWorkloadConfigContainer, true}, {gpuWorkloadConfigVMPassthrough, true}, {gpuWorkloadConfigVMVgpu, true},
79+
{"invalid", false}, {"", false},
80+
}
81+
for _, tc := range tests {
82+
if got := isValidWorkloadConfig(tc.config); got != tc.want {
83+
t.Errorf("isValidWorkloadConfig(%q) = %v, want %v", tc.config, got, tc.want)
84+
}
85+
}
86+
}
87+
88+
func TestHasOperandsDisabled(t *testing.T) {
89+
tests := []struct {
90+
labels map[string]string
91+
want bool
92+
}{
93+
{map[string]string{commonOperandsLabelKey: "false"}, true},
94+
{map[string]string{commonOperandsLabelKey: commonOperandsLabelValue}, false},
95+
{map[string]string{}, false},
96+
}
97+
for _, tc := range tests {
98+
if got := hasOperandsDisabled(tc.labels); got != tc.want {
99+
t.Errorf("hasOperandsDisabled(%v) = %v, want %v", tc.labels, got, tc.want)
100+
}
101+
}
102+
}
103+
104+
func TestHasNFDLabels(t *testing.T) {
105+
tests := []struct {
106+
labels map[string]string
107+
want bool
108+
}{
109+
{map[string]string{nfdLabelPrefix + "cpu": "true"}, true},
110+
{map[string]string{"other-label": "value"}, false},
111+
{map[string]string{}, false},
112+
}
113+
for _, tc := range tests {
114+
if got := hasNFDLabels(tc.labels); got != tc.want {
115+
t.Errorf("hasNFDLabels(%v) = %v, want %v", tc.labels, got, tc.want)
116+
}
117+
}
118+
}
119+
120+
func TestHasMIGManagerLabel(t *testing.T) {
121+
tests := []struct {
122+
labels map[string]string
123+
want bool
124+
}{
125+
{map[string]string{migManagerLabelKey: migManagerLabelValue}, true},
126+
{map[string]string{"other": "value"}, false},
127+
}
128+
for _, tc := range tests {
129+
if got := hasMIGManagerLabel(tc.labels); got != tc.want {
130+
t.Errorf("hasMIGManagerLabel(%v) = %v, want %v", tc.labels, got, tc.want)
131+
}
132+
}
133+
}
134+
135+
func TestHasCommonGPULabel(t *testing.T) {
136+
tests := []struct {
137+
labels map[string]string
138+
want bool
139+
}{
140+
{map[string]string{commonGPULabelKey: commonGPULabelValue}, true},
141+
{map[string]string{commonGPULabelKey: "false"}, false},
142+
{map[string]string{}, false},
143+
}
144+
for _, tc := range tests {
145+
if got := hasCommonGPULabel(tc.labels); got != tc.want {
146+
t.Errorf("hasCommonGPULabel(%v) = %v, want %v", tc.labels, got, tc.want)
147+
}
148+
}
149+
}
150+
151+
func TestHasGPULabels(t *testing.T) {
152+
tests := []struct {
153+
labels map[string]string
154+
want bool
155+
}{
156+
{map[string]string{nfdLabelPrefix + "pci-10de.present": "true"}, true},
157+
{map[string]string{nfdLabelPrefix + "pci-0302_10de.present": "true"}, true},
158+
{map[string]string{nfdLabelPrefix + "pci-0300_10de.present": "true"}, true},
159+
{map[string]string{nfdLabelPrefix + "pci-10de.present": "false"}, false},
160+
{map[string]string{"other": "true"}, false},
161+
}
162+
for _, tc := range tests {
163+
if got := hasGPULabels(tc.labels); got != tc.want {
164+
t.Errorf("hasGPULabels(%v) = %v, want %v", tc.labels, got, tc.want)
165+
}
166+
}
167+
}
168+
169+
func TestHasMIGCapableGPU(t *testing.T) {
170+
tests := []struct {
171+
labels map[string]string
172+
want bool
173+
}{
174+
{map[string]string{migCapableLabelKey: migCapableLabelValue}, true},
175+
{map[string]string{migCapableLabelKey: "false"}, false},
176+
{map[string]string{gpuProductLabelKey: "NVIDIA-A100"}, true},
177+
{map[string]string{gpuProductLabelKey: "NVIDIA-H100"}, true},
178+
{map[string]string{gpuProductLabelKey: "NVIDIA-A30"}, true},
179+
{map[string]string{gpuProductLabelKey: "NVIDIA-T4"}, false},
180+
{map[string]string{vgpuHostDriverLabelKey: "535.54"}, false},
181+
{map[string]string{}, false},
182+
}
183+
for _, tc := range tests {
184+
if got := hasMIGCapableGPU(tc.labels); got != tc.want {
185+
t.Errorf("hasMIGCapableGPU(%v) = %v, want %v", tc.labels, got, tc.want)
186+
}
187+
}
188+
}

controllers/upgrade_controller.go

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ import (
2929
"k8s.io/apimachinery/pkg/util/intstr"
3030
"k8s.io/client-go/util/workqueue"
3131
"sigs.k8s.io/controller-runtime/pkg/controller"
32+
"sigs.k8s.io/controller-runtime/pkg/event"
3233
"sigs.k8s.io/controller-runtime/pkg/handler"
3334
"sigs.k8s.io/controller-runtime/pkg/log"
3435
"sigs.k8s.io/controller-runtime/pkg/predicate"
@@ -254,14 +255,20 @@ func (r *UpgradeReconciler) SetupWithManager(ctx context.Context, mgr ctrl.Manag
254255
return getClusterPoliciesToReconcile(ctx, mgr.GetClient())
255256
}
256257

257-
// Watch for changes to node labels
258-
// TODO: only watch for changes to upgrade state label
258+
// Only watch for changes to the upgrade state label
259+
upgradeStateLabelPredicate := predicate.TypedFuncs[*corev1.Node]{
260+
UpdateFunc: func(e event.TypedUpdateEvent[*corev1.Node]) bool {
261+
label := upgrade.GetUpgradeStateLabelKey()
262+
return e.ObjectOld.Labels[label] != e.ObjectNew.Labels[label]
263+
},
264+
}
265+
259266
err = c.Watch(
260267
source.Kind(
261268
mgr.GetCache(),
262269
&corev1.Node{},
263270
handler.TypedEnqueueRequestsFromMapFunc[*corev1.Node](nodeMapFn),
264-
predicate.TypedLabelChangedPredicate[*corev1.Node]{},
271+
upgradeStateLabelPredicate,
265272
),
266273
)
267274
if err != nil {

0 commit comments

Comments
 (0)