Skip to content

Commit f1e2d0e

Browse files
committed
pass a single context throughout the device-plugin method call stack
This change follows the Go best practices. With a single ctx reference, we allow for the proper propagation of cancellations and graceful terminations across all goroutines of the device-plugin application. Signed-off-by: Tariq Ibrahim <tibrahim@nvidia.com>
1 parent da4ce42 commit f1e2d0e

File tree

5 files changed

+18
-4
lines changed

5 files changed

+18
-4
lines changed

cmd/nvidia-device-plugin/main.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -354,7 +354,7 @@ func startPlugins(c *cli.Context, o *options) ([]plugin.Interface, bool, error)
354354

355355
// Get the set of plugins.
356356
klog.Info("Retrieving plugins.")
357-
plugins, err := GetPlugins(infolib, nvmllib, devicelib, config)
357+
plugins, err := GetPlugins(c.Context, infolib, nvmllib, devicelib, config)
358358
if err != nil {
359359
return nil, false, fmt.Errorf("error getting plugins: %v", err)
360360
}

cmd/nvidia-device-plugin/plugin-manager.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
package main
1818

1919
import (
20+
"context"
2021
"fmt"
2122

2223
"github.com/NVIDIA/go-nvlib/pkg/nvlib/device"
@@ -30,7 +31,7 @@ import (
3031
)
3132

3233
// GetPlugins returns a set of plugins for the specified configuration.
33-
func GetPlugins(infolib info.Interface, nvmllib nvml.Interface, devicelib device.Interface, config *spec.Config) ([]plugin.Interface, error) {
34+
func GetPlugins(ctx context.Context, infolib info.Interface, nvmllib nvml.Interface, devicelib device.Interface, config *spec.Config) ([]plugin.Interface, error) {
3435
// TODO: We could consider passing this as an argument since it should already be used to construct nvmllib.
3536
driverRoot := root(*config.Flags.Plugin.ContainerDriverRoot)
3637

@@ -62,6 +63,7 @@ func GetPlugins(infolib info.Interface, nvmllib nvml.Interface, devicelib device
6263
}
6364

6465
plugins, err := plugin.New(infolib, nvmllib, devicelib,
66+
plugin.WithContext(ctx),
6567
plugin.WithCDIHandler(cdiHandler),
6668
plugin.WithConfig(config),
6769
plugin.WithDeviceListStrategies(deviceListStrategies),

internal/plugin/factory.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
package plugin
1818

1919
import (
20+
"context"
2021
"fmt"
2122

2223
"github.com/NVIDIA/go-nvlib/pkg/nvlib/device"
@@ -31,6 +32,7 @@ import (
3132
)
3233

3334
type options struct {
35+
ctx context.Context
3436
infolib info.Interface
3537
nvmllib nvml.Interface
3638
devicelib device.Interface

internal/plugin/options.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
package plugin
1818

1919
import (
20+
"context"
21+
2022
"github.com/NVIDIA/go-nvlib/pkg/nvlib/info"
2123
"github.com/NVIDIA/go-nvml/pkg/nvml"
2224

@@ -28,6 +30,12 @@ import (
2830
// Option is a function that configures a options
2931
type Option func(*options)
3032

33+
func WithContext(ctx context.Context) Option {
34+
return func(m *options) {
35+
m.ctx = ctx
36+
}
37+
}
38+
3139
// WithCDIHandler sets the CDI handler for the options
3240
func WithCDIHandler(handler cdi.Interface) Option {
3341
return func(m *options) {

internal/plugin/server.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ const (
5050

5151
// nvidiaDevicePlugin implements the Kubernetes device plugin API
5252
type nvidiaDevicePlugin struct {
53+
ctx context.Context
5354
rm rm.ResourceManager
5455
config *spec.Config
5556
deviceListStrategies spec.DeviceListStrategies
@@ -75,6 +76,7 @@ func (o *options) devicePluginForResource(resourceManager rm.ResourceManager) (I
7576
}
7677

7778
plugin := nvidiaDevicePlugin{
79+
ctx: o.ctx,
7880
rm: resourceManager,
7981
config: o.config,
8082
deviceListStrategies: o.deviceListStrategies,
@@ -245,7 +247,7 @@ func (plugin *nvidiaDevicePlugin) Register(kubeletSocket string) error {
245247
},
246248
}
247249

248-
_, err = client.Register(context.Background(), reqt)
250+
_, err = client.Register(plugin.ctx, reqt)
249251
if err != nil {
250252
return err
251253
}
@@ -432,7 +434,7 @@ func (plugin *nvidiaDevicePlugin) PreStartContainer(context.Context, *pluginapi.
432434

433435
// dial establishes the gRPC communication with the registered device plugin.
434436
func (plugin *nvidiaDevicePlugin) dial(unixSocketPath string, timeout time.Duration) (*grpc.ClientConn, error) {
435-
ctx, cancel := context.WithTimeout(context.Background(), timeout)
437+
ctx, cancel := context.WithTimeout(plugin.ctx, timeout)
436438
defer cancel()
437439
//nolint:staticcheck // TODO: Switch to grpc.NewClient
438440
c, err := grpc.DialContext(ctx, unixSocketPath,

0 commit comments

Comments
 (0)