Skip to content

Commit a01f0a0

Browse files
thatmattlongaggarwal0009timraymond
authored
feat: add device plugin support to CNS (#2886)
* feat: add device plugin support to CNS * Add UT coverage and linter fixes * fix windows-latest lint issues * Update cns/deviceplugin/pluginmanager.go Co-authored-by: Timothy J. Raymond <[email protected]> Signed-off-by: aggarwal0009 <[email protected]> * remove test run output file * linter fixes * resolve readability related comments * move nolint annotations inline * remove unnecessary nolint * update UT * deleted gitignore for test file * fix goroutine eak * pr feedback cleanup * move devicePrefix to a constant * pr refactoring * refactored to make PluginManager generic * Update trackDevice to return nil * Add documentation * fix shutdownCh initialization in server.go * Fix UTs * fix merge conflict errors --------- Signed-off-by: aggarwal0009 <[email protected]> Co-authored-by: aggarwal0009 <[email protected]> Co-authored-by: Timothy J. Raymond <[email protected]>
1 parent 4971fb6 commit a01f0a0

File tree

9 files changed

+1024
-37
lines changed

9 files changed

+1024
-37
lines changed

cns/deviceplugin/plugin.go

Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
package deviceplugin
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"net"
7+
"os"
8+
"path"
9+
"path/filepath"
10+
"strings"
11+
"sync"
12+
"time"
13+
14+
"github.com/Azure/azure-container-networking/crd/multitenancy/api/v1alpha1"
15+
"github.com/pkg/errors"
16+
"go.uber.org/zap"
17+
"google.golang.org/grpc"
18+
"google.golang.org/grpc/credentials/insecure"
19+
"k8s.io/kubelet/pkg/apis/deviceplugin/v1beta1"
20+
)
21+
22+
type Plugin struct {
23+
Logger *zap.Logger
24+
ResourceName string
25+
SocketWatcher *SocketWatcher
26+
Socket string
27+
deviceCountMutex sync.Mutex
28+
deviceCount int
29+
deviceType v1alpha1.DeviceType
30+
kubeletSocket string
31+
deviceCheckInterval time.Duration
32+
devicePluginDirectory string
33+
}
34+
35+
func NewPlugin(l *zap.Logger, resourceName string, socketWatcher *SocketWatcher, pluginDir string,
36+
initialDeviceCount int, deviceType v1alpha1.DeviceType, kubeletSocket string, deviceCheckInterval time.Duration,
37+
) *Plugin {
38+
return &Plugin{
39+
Logger: l.With(zap.String("resourceName", resourceName)),
40+
ResourceName: resourceName,
41+
SocketWatcher: socketWatcher,
42+
Socket: getSocketName(pluginDir, deviceType),
43+
deviceCount: initialDeviceCount,
44+
deviceType: deviceType,
45+
kubeletSocket: kubeletSocket,
46+
deviceCheckInterval: deviceCheckInterval,
47+
devicePluginDirectory: pluginDir,
48+
}
49+
}
50+
51+
// Run runs the plugin until the context is cancelled, restarting the server as needed
52+
func (p *Plugin) Run(ctx context.Context) {
53+
defer p.mustCleanUp()
54+
for {
55+
select {
56+
case <-ctx.Done():
57+
return
58+
default:
59+
p.Logger.Info("starting device plugin for resource", zap.String("resource", p.ResourceName))
60+
if err := p.run(ctx); err != nil {
61+
p.Logger.Error("device plugin for resource exited", zap.String("resource", p.ResourceName), zap.Error(err))
62+
}
63+
}
64+
}
65+
}
66+
67+
// Here we start the gRPC server and wait for it to be ready
68+
// Once the server is ready, device plugin registers with the Kubelet
69+
// so that it can start serving the kubelet requests
70+
func (p *Plugin) run(ctx context.Context) error {
71+
childCtx, cancel := context.WithCancel(ctx)
72+
defer cancel()
73+
74+
s := NewServer(p.Logger, p.Socket, p, p.deviceCheckInterval)
75+
// Run starts the grpc server and blocks until an error or context is cancelled
76+
runErrChan := make(chan error, 2) //nolint:gomnd // disabled in favor of readability
77+
go func(errChan chan error) {
78+
if err := s.Run(childCtx); err != nil {
79+
errChan <- err
80+
}
81+
}(runErrChan)
82+
83+
// Wait till the server is ready before registering with kubelet
84+
// This call is not blocking and returns as soon as the server is ready
85+
readyErrChan := make(chan error, 2) //nolint:gomnd // disabled in favor of readability
86+
go func(errChan chan error) {
87+
errChan <- s.Ready(childCtx)
88+
}(readyErrChan)
89+
90+
select {
91+
case err := <-runErrChan:
92+
return errors.Wrap(err, "error starting grpc server")
93+
case err := <-readyErrChan:
94+
if err != nil {
95+
return errors.Wrap(err, "error waiting on grpc server to be ready")
96+
}
97+
case <-ctx.Done():
98+
return nil
99+
}
100+
101+
p.Logger.Info("registering with kubelet")
102+
// register with kubelet
103+
if err := p.registerWithKubelet(childCtx); err != nil {
104+
return errors.Wrap(err, "failed to register with kubelet")
105+
}
106+
107+
// run until the socket goes away or the context is cancelled
108+
<-p.SocketWatcher.WatchSocket(childCtx, p.Socket)
109+
return nil
110+
}
111+
112+
func (p *Plugin) registerWithKubelet(ctx context.Context) error {
113+
conn, err := grpc.Dial(p.kubeletSocket, grpc.WithTransportCredentials(insecure.NewCredentials()), //nolint:staticcheck // TODO: Move to grpc.NewClient method
114+
grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
115+
d := &net.Dialer{}
116+
conn, err := d.DialContext(ctx, "unix", addr)
117+
if err != nil {
118+
return nil, errors.Wrap(err, "failed to dial context")
119+
}
120+
return conn, nil
121+
}))
122+
if err != nil {
123+
return errors.Wrap(err, "error connecting to kubelet")
124+
}
125+
defer conn.Close()
126+
127+
client := v1beta1.NewRegistrationClient(conn)
128+
request := &v1beta1.RegisterRequest{
129+
Version: v1beta1.Version,
130+
Endpoint: filepath.Base(p.Socket),
131+
ResourceName: p.ResourceName,
132+
}
133+
if _, err = client.Register(ctx, request); err != nil {
134+
return errors.Wrap(err, "error sending request to register with kubelet")
135+
}
136+
return nil
137+
}
138+
139+
func (p *Plugin) mustCleanUp() {
140+
p.Logger.Info("cleaning up device plugin")
141+
if err := os.Remove(p.Socket); err != nil && !os.IsNotExist(err) {
142+
p.Logger.Panic("failed to remove socket", zap.Error(err))
143+
}
144+
}
145+
146+
func (p *Plugin) CleanOldState() error {
147+
entries, err := os.ReadDir(p.devicePluginDirectory)
148+
if err != nil {
149+
return errors.Wrap(err, "error listing existing device plugin sockets")
150+
}
151+
for _, entry := range entries {
152+
if strings.HasPrefix(entry.Name(), path.Base(getSocketPrefix(p.devicePluginDirectory, p.deviceType))) {
153+
// try to delete it
154+
f := path.Join(p.devicePluginDirectory, entry.Name())
155+
if err := os.Remove(f); err != nil {
156+
return errors.Wrapf(err, "error removing old socket %q", f)
157+
}
158+
}
159+
}
160+
return nil
161+
}
162+
163+
func (p *Plugin) UpdateDeviceCount(count int) {
164+
p.deviceCountMutex.Lock()
165+
p.deviceCount = count
166+
p.deviceCountMutex.Unlock()
167+
}
168+
169+
func (p *Plugin) getDeviceCount() int {
170+
p.deviceCountMutex.Lock()
171+
defer p.deviceCountMutex.Unlock()
172+
return p.deviceCount
173+
}
174+
175+
// getSocketPrefix returns a fully qualified path prefix for a given device type. For example, if the device plugin directory is
176+
// /home/foo and the device type is acn.azure.com/vnet-nic, this function returns /home/foo/acn.azure.com_vnet-nic
177+
func getSocketPrefix(devicePluginDirectory string, deviceType v1alpha1.DeviceType) string {
178+
sanitizedDeviceName := strings.ReplaceAll(string(deviceType), "/", "_")
179+
return path.Join(devicePluginDirectory, sanitizedDeviceName)
180+
}
181+
182+
func getSocketName(devicePluginDirectory string, deviceType v1alpha1.DeviceType) string {
183+
return fmt.Sprintf("%s-%d.sock", getSocketPrefix(devicePluginDirectory, deviceType), time.Now().Unix())
184+
}

cns/deviceplugin/pluginmanager.go

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
package deviceplugin
2+
3+
import (
4+
"context"
5+
"sync"
6+
"time"
7+
8+
"github.com/Azure/azure-container-networking/crd/multitenancy/api/v1alpha1"
9+
"github.com/pkg/errors"
10+
"go.uber.org/zap"
11+
"k8s.io/kubelet/pkg/apis/deviceplugin/v1beta1"
12+
)
13+
14+
const (
15+
defaultDevicePluginDirectory = "/var/lib/kubelet/device-plugins"
16+
defaultDeviceCheckInterval = 5 * time.Second
17+
)
18+
19+
type pluginManagerOptions struct {
20+
devicePluginDirectory string
21+
kubeletSocket string
22+
deviceCheckInterval time.Duration
23+
}
24+
25+
type pluginManagerOption func(*pluginManagerOptions)
26+
27+
func PluginManagerSocketPrefix(prefix string) func(*pluginManagerOptions) {
28+
return func(opts *pluginManagerOptions) {
29+
opts.devicePluginDirectory = prefix
30+
}
31+
}
32+
33+
func PluginManagerKubeletSocket(socket string) func(*pluginManagerOptions) {
34+
return func(opts *pluginManagerOptions) {
35+
opts.kubeletSocket = socket
36+
}
37+
}
38+
39+
func PluginDeviceCheckInterval(i time.Duration) func(*pluginManagerOptions) {
40+
return func(opts *pluginManagerOptions) {
41+
opts.deviceCheckInterval = i
42+
}
43+
}
44+
45+
// PluginManager runs device plugins for vnet nics and ib nics
46+
type PluginManager struct {
47+
Logger *zap.Logger
48+
plugins []*Plugin
49+
socketWatcher *SocketWatcher
50+
options pluginManagerOptions
51+
mu sync.Mutex
52+
}
53+
54+
func NewPluginManager(l *zap.Logger, opts ...pluginManagerOption) *PluginManager {
55+
logger := l.With(zap.String("component", "devicePlugin"))
56+
socketWatcher := NewSocketWatcher(logger)
57+
options := pluginManagerOptions{
58+
devicePluginDirectory: defaultDevicePluginDirectory,
59+
kubeletSocket: v1beta1.KubeletSocket,
60+
deviceCheckInterval: defaultDeviceCheckInterval,
61+
}
62+
for _, o := range opts {
63+
o(&options)
64+
}
65+
return &PluginManager{
66+
Logger: logger,
67+
socketWatcher: socketWatcher,
68+
options: options,
69+
}
70+
}
71+
72+
func (pm *PluginManager) AddPlugin(deviceType v1alpha1.DeviceType, deviceCount int) *PluginManager {
73+
pm.mu.Lock()
74+
defer pm.mu.Unlock()
75+
p := NewPlugin(pm.Logger, string(deviceType), pm.socketWatcher,
76+
pm.options.devicePluginDirectory, deviceCount, deviceType, pm.options.kubeletSocket, pm.options.deviceCheckInterval)
77+
pm.plugins = append(pm.plugins, p)
78+
return pm
79+
}
80+
81+
// Run runs the plugin manager until the context is cancelled or error encountered
82+
func (pm *PluginManager) Run(ctx context.Context) error {
83+
// clean up any leftover state from previous failed plugins
84+
// this can happen if the process crashes before it is able to clean up after itself
85+
for _, plugin := range pm.plugins {
86+
if err := plugin.CleanOldState(); err != nil {
87+
return errors.Wrap(err, "error cleaning state from previous plugin process")
88+
}
89+
}
90+
91+
var wg sync.WaitGroup
92+
for _, plugin := range pm.plugins {
93+
wg.Add(1) //nolint:gomnd // in favor of readability
94+
go func(p *Plugin) {
95+
defer wg.Done()
96+
p.Run(ctx)
97+
}(plugin)
98+
}
99+
100+
wg.Wait()
101+
return nil
102+
}
103+
104+
func (pm *PluginManager) TrackDevices(deviceType v1alpha1.DeviceType, count int) {
105+
pm.mu.Lock()
106+
defer pm.mu.Unlock()
107+
for _, plugin := range pm.plugins {
108+
if plugin.deviceType == deviceType {
109+
plugin.UpdateDeviceCount(count)
110+
break
111+
}
112+
}
113+
}

0 commit comments

Comments
 (0)