Skip to content

Commit 8d1f64f

Browse files
committed
feat: Support context for all Nutanix client calls
1 parent e012852 commit 8d1f64f

File tree

9 files changed

+231
-90
lines changed

9 files changed

+231
-90
lines changed

pkg/webhook/preflight/nutanix/clients.go

Lines changed: 105 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ type client interface {
2626
)
2727

2828
GetImageById(
29+
ctx context.Context,
2930
uuid *string,
3031
args ...map[string]interface{},
3132
) (
@@ -34,6 +35,7 @@ type client interface {
3435
)
3536

3637
ListImages(
38+
ctx context.Context,
3739
page_ *int,
3840
limit_ *int,
3941
filter_ *string,
@@ -46,13 +48,15 @@ type client interface {
4648
)
4749

4850
GetClusterById(
51+
ctx context.Context,
4952
uuid *string,
5053
args ...map[string]interface{},
5154
) (
5255
*clustermgmtv4.GetClusterApiResponse, error,
5356
)
5457

5558
ListClusters(
59+
ctx context.Context,
5660
page_ *int,
5761
limit_ *int,
5862
filter_ *string,
@@ -65,6 +69,7 @@ type client interface {
6569
error,
6670
)
6771
ListStorageContainers(
72+
ctx context.Context,
6873
page_ *int,
6974
limit_ *int,
7075
filter_ *string,
@@ -77,13 +82,15 @@ type client interface {
7782
)
7883

7984
GetSubnetById(
85+
ctx context.Context,
8086
uuid *string,
8187
args ...map[string]interface{},
8288
) (
8389
*netv4.GetSubnetApiResponse, error,
8490
)
8591

8692
ListSubnets(
93+
ctx context.Context,
8794
page_ *int,
8895
limit_ *int,
8996
filter_ *string,
@@ -211,23 +218,23 @@ func (c *clientWrapper) GetCurrentLoggedInUser(
211218
}
212219

213220
func (c *clientWrapper) GetImageById(
221+
ctx context.Context,
214222
uuid *string,
215223
args ...map[string]interface{},
216224
) (
217225
*vmmv4.GetImageApiResponse,
218226
error,
219227
) {
220-
resp, err := c.GetImageByIdFunc(
221-
uuid,
222-
args...,
223-
)
224-
if err != nil {
225-
return nil, err
226-
}
227-
return resp, nil
228+
return callWithContext(ctx, func() (*vmmv4.GetImageApiResponse, error) {
229+
return c.GetImageByIdFunc(
230+
uuid,
231+
args...,
232+
)
233+
})
228234
}
229235

230236
func (c *clientWrapper) ListImages(
237+
ctx context.Context,
231238
page_ *int,
232239
limit_ *int,
233240
filter_ *string,
@@ -238,38 +245,36 @@ func (c *clientWrapper) ListImages(
238245
*vmmv4.ListImagesApiResponse,
239246
error,
240247
) {
241-
resp, err := c.ListImagesFunc(
242-
page_,
243-
limit_,
244-
filter_,
245-
orderby_,
246-
select_,
247-
args...,
248-
)
249-
if err != nil {
250-
return nil, err
251-
}
252-
return resp, nil
248+
return callWithContext(ctx, func() (*vmmv4.ListImagesApiResponse, error) {
249+
return c.ListImagesFunc(
250+
page_,
251+
limit_,
252+
filter_,
253+
orderby_,
254+
select_,
255+
args...,
256+
)
257+
})
253258
}
254259

255260
func (c *clientWrapper) GetClusterById(
261+
ctx context.Context,
256262
uuid *string,
257263
args ...map[string]interface{},
258264
) (
259265
*clustermgmtv4.GetClusterApiResponse,
260266
error,
261267
) {
262-
resp, err := c.GetClusterByIdFunc(
263-
uuid,
264-
args...,
265-
)
266-
if err != nil {
267-
return nil, err
268-
}
269-
return resp, nil
268+
return callWithContext(ctx, func() (*clustermgmtv4.GetClusterApiResponse, error) {
269+
return c.GetClusterByIdFunc(
270+
uuid,
271+
args...,
272+
)
273+
})
270274
}
271275

272276
func (c *clientWrapper) ListClusters(
277+
ctx context.Context,
273278
page_ *int,
274279
limit_ *int,
275280
filter_ *string,
@@ -281,22 +286,21 @@ func (c *clientWrapper) ListClusters(
281286
*clustermgmtv4.ListClustersApiResponse,
282287
error,
283288
) {
284-
resp, err := c.ListClustersFunc(
285-
page_,
286-
limit_,
287-
filter_,
288-
orderby_,
289-
apply_,
290-
select_,
291-
args...,
292-
)
293-
if err != nil {
294-
return nil, err
295-
}
296-
return resp, nil
289+
return callWithContext(ctx, func() (*clustermgmtv4.ListClustersApiResponse, error) {
290+
return c.ListClustersFunc(
291+
page_,
292+
limit_,
293+
filter_,
294+
orderby_,
295+
apply_,
296+
select_,
297+
args...,
298+
)
299+
})
297300
}
298301

299302
func (c *clientWrapper) ListStorageContainers(
303+
ctx context.Context,
300304
page_ *int,
301305
limit_ *int,
302306
filter_ *string,
@@ -307,38 +311,36 @@ func (c *clientWrapper) ListStorageContainers(
307311
*clustermgmtv4.ListStorageContainersApiResponse,
308312
error,
309313
) {
310-
resp, err := c.ListStorageContainersFunc(
311-
page_,
312-
limit_,
313-
filter_,
314-
orderby_,
315-
select_,
316-
args...,
317-
)
318-
if err != nil {
319-
return nil, err
320-
}
321-
return resp, nil
314+
return callWithContext(ctx, func() (*clustermgmtv4.ListStorageContainersApiResponse, error) {
315+
return c.ListStorageContainersFunc(
316+
page_,
317+
limit_,
318+
filter_,
319+
orderby_,
320+
select_,
321+
args...,
322+
)
323+
})
322324
}
323325

324326
func (c *clientWrapper) GetSubnetById(
327+
ctx context.Context,
325328
uuid *string,
326329
args ...map[string]interface{},
327330
) (
328331
*netv4.GetSubnetApiResponse,
329332
error,
330333
) {
331-
resp, err := c.GetSubnetByIdFunc(
332-
uuid,
333-
args...,
334-
)
335-
if err != nil {
336-
return nil, err
337-
}
338-
return resp, nil
334+
return callWithContext(ctx, func() (*netv4.GetSubnetApiResponse, error) {
335+
return c.GetSubnetByIdFunc(
336+
uuid,
337+
args...,
338+
)
339+
})
339340
}
340341

341342
func (c *clientWrapper) ListSubnets(
343+
ctx context.Context,
342344
page_ *int,
343345
limit_ *int,
344346
filter_ *string,
@@ -350,17 +352,47 @@ func (c *clientWrapper) ListSubnets(
350352
*netv4.ListSubnetsApiResponse,
351353
error,
352354
) {
353-
resp, err := c.ListSubnetsFunc(
354-
page_,
355-
limit_,
356-
filter_,
357-
orderby_,
358-
expand_,
359-
select_,
360-
args...,
361-
)
362-
if err != nil {
363-
return nil, err
355+
return callWithContext(ctx, func() (*netv4.ListSubnetsApiResponse, error) {
356+
return c.ListSubnetsFunc(
357+
page_,
358+
limit_,
359+
filter_,
360+
orderby_,
361+
expand_,
362+
select_,
363+
args...,
364+
)
365+
})
366+
}
367+
368+
// callWithContext is a helper function that immediately responds to context cancellation,
369+
// while calling a long-running, non-preemptible function. The long-running function always
370+
// runs to completion, but its result is only returned if the context is not cancelled.
371+
func callWithContext[T any](ctx context.Context, f func() (T, error)) (resp T, err error) {
372+
respCh := make(chan T)
373+
errCh := make(chan error)
374+
375+
go func() {
376+
resp, err := f()
377+
select {
378+
case <-ctx.Done():
379+
// Context was cancelled before function returned. We assume no one wants the result anymore.
380+
default:
381+
if err != nil {
382+
errCh <- err
383+
}
384+
respCh <- resp
385+
}
386+
close(respCh)
387+
close(errCh)
388+
}()
389+
390+
select {
391+
case <-ctx.Done():
392+
return resp, ctx.Err()
393+
case err := <-errCh:
394+
return resp, err
395+
case resp := <-respCh:
396+
return resp, nil
364397
}
365-
return resp, nil
366398
}

pkg/webhook/preflight/nutanix/clients_test.go

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@ package nutanix
55

66
import (
77
"context"
8+
"errors"
9+
"testing"
10+
"time"
811

912
ctrlclient "sigs.k8s.io/controller-runtime/pkg/client"
1013
)
@@ -32,3 +35,90 @@ func (m *mockKubeClient) Get(
3235
func (m *mockKubeClient) SubResource(subResource string) ctrlclient.SubResourceClient {
3336
return m.SubResourceClient
3437
}
38+
39+
func TestCallWithContext(t *testing.T) {
40+
t.Parallel()
41+
testSuccessValue := "success"
42+
testError := errors.New("test error")
43+
44+
tests := []struct {
45+
name string
46+
ctx func() (context.Context, context.CancelFunc)
47+
f func() (string, error)
48+
wantVal string
49+
wantErr error
50+
cancelAfter time.Duration
51+
}{
52+
{
53+
name: "should return value on success",
54+
ctx: func() (context.Context, context.CancelFunc) {
55+
return context.Background(), func() {}
56+
},
57+
f: func() (string, error) {
58+
return testSuccessValue, nil
59+
},
60+
wantVal: testSuccessValue,
61+
wantErr: nil,
62+
},
63+
{
64+
name: "should return error when function fails",
65+
ctx: func() (context.Context, context.CancelFunc) {
66+
return context.Background(), func() {}
67+
},
68+
f: func() (string, error) {
69+
return "", testError
70+
},
71+
wantErr: testError,
72+
},
73+
{
74+
name: "should return context error when context is cancelled during execution",
75+
ctx: func() (context.Context, context.CancelFunc) {
76+
return context.WithCancel(context.Background())
77+
},
78+
f: func() (string, error) {
79+
time.Sleep(100 * time.Millisecond)
80+
return testSuccessValue, nil
81+
},
82+
wantErr: context.Canceled,
83+
cancelAfter: 10 * time.Millisecond,
84+
},
85+
{
86+
name: "should return context error when context is already cancelled",
87+
ctx: func() (context.Context, context.CancelFunc) {
88+
ctx, cancel := context.WithCancel(context.Background())
89+
cancel()
90+
return ctx, func() {}
91+
},
92+
f: func() (string, error) {
93+
t.Log("this function should not have its result returned")
94+
return testSuccessValue, nil
95+
},
96+
wantErr: context.Canceled,
97+
},
98+
}
99+
100+
for _, tt := range tests {
101+
t.Run(tt.name, func(t *testing.T) {
102+
t.Parallel()
103+
ctx, cancel := tt.ctx()
104+
defer cancel()
105+
106+
if tt.cancelAfter > 0 {
107+
go func() {
108+
time.Sleep(tt.cancelAfter)
109+
cancel()
110+
}()
111+
}
112+
113+
gotVal, gotErr := callWithContext(ctx, tt.f)
114+
115+
if !errors.Is(gotErr, tt.wantErr) {
116+
t.Errorf("callWithContext() error = %v, wantErr %v", gotErr, tt.wantErr)
117+
}
118+
119+
if gotVal != tt.wantVal {
120+
t.Errorf("callWithContext() gotVal = %s, wantVal %s", gotVal, tt.wantVal)
121+
}
122+
})
123+
}
124+
}

0 commit comments

Comments
 (0)