Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions device-plugin/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,14 @@ module github.com/hyperlight-dev/hyperlight-device-plugin
go 1.25.0

require (
github.com/fsnotify/fsnotify v1.8.0
google.golang.org/grpc v1.78.0
k8s.io/klog/v2 v2.130.1
k8s.io/kubelet v0.35.0
)

require (
github.com/go-logr/logr v1.4.3 // indirect
golang.org/x/net v0.48.0 // indirect
golang.org/x/sys v0.39.0 // indirect
golang.org/x/text v0.32.0 // indirect
Expand Down
4 changes: 4 additions & 0 deletions device-plugin/go.sum
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
github.com/fsnotify/fsnotify v1.8.0 h1:dAwr6QBTBZIkG8roQaJjGof0pp0EeF+tNV7YBP3F/8M=
github.com/fsnotify/fsnotify v1.8.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
Expand Down Expand Up @@ -34,5 +36,7 @@ google.golang.org/grpc v1.78.0 h1:K1XZG/yGDJnzMdd/uZHAkVqJE+xIDOcmdSFZkBUicNc=
google.golang.org/grpc v1.78.0/go.mod h1:I47qjTo4OKbMkjA/aOOwxDIiPSBofUtQUI5EfpWvW7U=
google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE=
google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
k8s.io/klog/v2 v2.130.1 h1:n9Xl7H1Xvksem4KFG4PYbdQCQxqc/tTUyrgXaOhHSzk=
k8s.io/klog/v2 v2.130.1/go.mod h1:3Jpz1GvMt720eyJH1ckRHK1EDfpxISzJ7I9OYgaDtPE=
k8s.io/kubelet v0.35.0 h1:8cgJHCBCKLYuuQ7/Pxb/qWbJfX1LXIw7790ce9xHq7c=
k8s.io/kubelet v0.35.0/go.mod h1:ciRzAXn7C4z5iB7FhG1L2CGPPXLTVCABDlbXt/Zz8YA=
123 changes: 93 additions & 30 deletions device-plugin/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ package main

import (
"context"
"flag"
"fmt"
"log"
"net"
"os"
"os/signal"
Expand All @@ -28,8 +28,10 @@ import (
"syscall"
"time"

"github.com/fsnotify/fsnotify"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"k8s.io/klog/v2"
pluginapi "k8s.io/kubelet/pkg/apis/deviceplugin/v1beta1"
)

Expand Down Expand Up @@ -66,7 +68,7 @@ func NewHyperlightDevicePlugin() (*HyperlightDevicePlugin, error) {
return nil, fmt.Errorf("no supported hypervisor found (/dev/kvm or /dev/mshv)")
}

log.Printf("Detected hypervisor: %s at %s", hypervisor, devicePath)
klog.Infof("Detected hypervisor: %s at %s", hypervisor, devicePath)

// Create CDI spec
if err := writeCDISpec(hypervisor, devicePath); err != nil {
Expand All @@ -84,7 +86,7 @@ func NewHyperlightDevicePlugin() (*HyperlightDevicePlugin, error) {
if count, err := strconv.Atoi(countStr); err == nil && count > 0 {
numDevices = count
} else {
log.Printf("Invalid DEVICE_COUNT '%s', using default %d", countStr, defaultDeviceCount)
klog.Warningf("Invalid DEVICE_COUNT '%s', using default %d", countStr, defaultDeviceCount)
}
}

Expand All @@ -95,7 +97,7 @@ func NewHyperlightDevicePlugin() (*HyperlightDevicePlugin, error) {
Health: pluginapi.Healthy,
}
}
log.Printf("Advertising %d hypervisor devices (configurable via DEVICE_COUNT)", numDevices)
klog.Infof("Advertising %d hypervisor devices (configurable via DEVICE_COUNT)", numDevices)

return &HyperlightDevicePlugin{
devices: devices,
Expand All @@ -113,7 +115,7 @@ func writeCDISpec(hypervisor, devicePath string) error {
if parsed, err := strconv.Atoi(uidStr); err == nil && parsed >= 0 {
uid = parsed
} else {
log.Printf("Invalid DEVICE_UID '%s', using default %d", uidStr, defaultDeviceUID)
klog.Warningf("Invalid DEVICE_UID '%s', using default %d", uidStr, defaultDeviceUID)
}
}

Expand All @@ -122,11 +124,11 @@ func writeCDISpec(hypervisor, devicePath string) error {
if parsed, err := strconv.Atoi(gidStr); err == nil && parsed >= 0 {
gid = parsed
} else {
log.Printf("Invalid DEVICE_GID '%s', using default %d", gidStr, defaultDeviceGID)
klog.Warningf("Invalid DEVICE_GID '%s', using default %d", gidStr, defaultDeviceGID)
}
}

log.Printf("CDI device ownership: uid=%d, gid=%d (configurable via DEVICE_UID/DEVICE_GID)", uid, gid)
klog.Infof("CDI device ownership: uid=%d, gid=%d (configurable via DEVICE_UID/DEVICE_GID)", uid, gid)

spec := fmt.Sprintf(`{
"cdiVersion": "0.6.0",
Expand Down Expand Up @@ -159,7 +161,7 @@ func writeCDISpec(hypervisor, devicePath string) error {
if err := os.WriteFile(cdiSpecPath, []byte(spec), 0644); err != nil {
return err
}
log.Printf("CDI spec written to %s", cdiSpecPath)
klog.Infof("CDI spec written to %s", cdiSpecPath)
return nil
}

Expand All @@ -173,7 +175,7 @@ func (p *HyperlightDevicePlugin) GetDevicePluginOptions(ctx context.Context, req

// ListAndWatch lists devices and watches for changes
func (p *HyperlightDevicePlugin) ListAndWatch(req *pluginapi.Empty, srv pluginapi.DevicePlugin_ListAndWatchServer) error {
log.Printf("ListAndWatch called, sending %d devices", len(p.devices))
klog.Infof("ListAndWatch called, sending %d devices", len(p.devices))

if err := srv.Send(&pluginapi.ListAndWatchResponse{Devices: p.devices}); err != nil {
return err
Expand All @@ -191,12 +193,16 @@ func (p *HyperlightDevicePlugin) ListAndWatch(req *pluginapi.Empty, srv pluginap
health := pluginapi.Healthy
if _, err := os.Stat(p.devicePath); err != nil {
health = pluginapi.Unhealthy
log.Printf("Device %s not found, marking unhealthy", p.devicePath)
klog.Warningf("Device %s not found, marking all devices unhealthy", p.devicePath)
}

// Check if health changed (compare against first device as representative)
if p.devices[0].Health != health {
p.devices[0].Health = health
log.Printf("Device health changed to %s", health)
// Update ALL devices - they all share the same underlying hypervisor device
for i := range p.devices {
p.devices[i].Health = health
}
klog.Infof("Device health changed to %s for all %d devices", health, len(p.devices))
if err := srv.Send(&pluginapi.ListAndWatchResponse{Devices: p.devices}); err != nil {
return err
}
Expand All @@ -207,7 +213,7 @@ func (p *HyperlightDevicePlugin) ListAndWatch(req *pluginapi.Empty, srv pluginap

// Allocate allocates devices to a container
func (p *HyperlightDevicePlugin) Allocate(ctx context.Context, req *pluginapi.AllocateRequest) (*pluginapi.AllocateResponse, error) {
log.Printf("Allocate called for %d containers", len(req.ContainerRequests))
klog.V(2).Infof("Allocate called for %d containers", len(req.ContainerRequests))

responses := make([]*pluginapi.ContainerAllocateResponse, len(req.ContainerRequests))

Expand All @@ -220,7 +226,7 @@ func (p *HyperlightDevicePlugin) Allocate(ctx context.Context, req *pluginapi.Al
},
},
}
log.Printf("Allocated CDI device: hyperlight.dev/hypervisor=%s", p.hypervisor)
klog.V(2).Infof("Allocated CDI device: hyperlight.dev/hypervisor=%s", p.hypervisor)
}

return &pluginapi.AllocateResponse{ContainerResponses: responses}, nil
Expand All @@ -242,7 +248,7 @@ func (p *HyperlightDevicePlugin) Start() error {

// Remove old socket
if err := os.Remove(serverSock); err != nil && !os.IsNotExist(err) {
log.Printf("Warning: failed to remove old socket: %v", err)
klog.Warningf("Failed to remove old socket: %v", err)
}

listener, err := net.Listen("unix", serverSock)
Expand All @@ -254,9 +260,9 @@ func (p *HyperlightDevicePlugin) Start() error {
pluginapi.RegisterDevicePluginServer(p.server, p)

go func() {
log.Printf("Starting gRPC server on %s", serverSock)
klog.Infof("Starting gRPC server on %s", serverSock)
if err := p.server.Serve(listener); err != nil {
log.Printf("gRPC server stopped: %v", err)
klog.V(1).Infof("gRPC server stopped: %v", err)
}
}()

Expand Down Expand Up @@ -297,7 +303,7 @@ func (p *HyperlightDevicePlugin) Register() error {
return fmt.Errorf("failed to register with kubelet: %v", err)
}

log.Printf("Registered with kubelet as %s", resourceName)
klog.Infof("Registered with kubelet as %s", resourceName)
return nil
}

Expand All @@ -307,14 +313,69 @@ func (p *HyperlightDevicePlugin) Stop() {
p.server.Stop()
}
os.Remove(serverSock)
log.Println("Device plugin stopped")
klog.Info("Device plugin stopped")
}

// newFSWatcher creates a filesystem watcher for kubelet restart detection.
func newFSWatcher(files ...string) (*fsnotify.Watcher, error) {
watcher, err := fsnotify.NewWatcher()
if err != nil {
return nil, err
}

for _, f := range files {
if err := watcher.Add(f); err != nil {
watcher.Close()
return nil, err
}
}

return watcher, nil
}

// watchKubeletRestart monitors for kubelet restarts by watching the plugin socket.
// watchKubeletRestart monitors for kubelet restarts using fsnotify.
// When kubelet restarts, it deletes all sockets in /var/lib/kubelet/device-plugins/.
// This function blocks until the socket is deleted, signaling a kubelet restart.
// This function blocks until it detects our plugin socket being deleted.
func (p *HyperlightDevicePlugin) watchKubeletRestart() {
log.Println("Watching for kubelet restart (socket deletion)...")
klog.Info("Watching for kubelet restart using fsnotify...")

watcher, err := newFSWatcher(pluginapi.DevicePluginPath)
if err != nil {
klog.Errorf("Failed to create fsnotify watcher, falling back to polling: %v", err)
p.watchKubeletRestartPolling()
return
}
defer watcher.Close()

for {
select {
case <-p.stopCh:
return
case event, ok := <-watcher.Events:
if !ok {
klog.Warning("fsnotify events channel closed, falling back to polling")
p.watchKubeletRestartPolling()
return
}
if event.Name == serverSock && (event.Op&fsnotify.Remove) == fsnotify.Remove {
klog.Info("Plugin socket deleted - kubelet may have restarted")
return
}
case err, ok := <-watcher.Errors:
if !ok {
klog.Warning("fsnotify errors channel closed, falling back to polling")
p.watchKubeletRestartPolling()
return
}
klog.Warningf("fsnotify error: %v", err)
}
}
}

// watchKubeletRestartPolling is a fallback method using polling.
// Used when fsnotify is unavailable.
func (p *HyperlightDevicePlugin) watchKubeletRestartPolling() {
klog.Info("Watching for kubelet restart (polling)...")

ticker := time.NewTicker(time.Second)
defer ticker.Stop()
Expand All @@ -324,22 +385,24 @@ func (p *HyperlightDevicePlugin) watchKubeletRestart() {
case <-p.stopCh:
return
case <-ticker.C:
// Check if our socket still exists
if _, err := os.Stat(serverSock); os.IsNotExist(err) {
log.Println("Plugin socket deleted - kubelet may have restarted")
klog.Info("Plugin socket deleted - kubelet may have restarted")
return
}
}
}
}

func main() {
log.SetFlags(log.LstdFlags | log.Lshortfile)
log.Println("Starting Hyperlight Device Plugin")
klog.InitFlags(nil)
flag.Parse()
defer klog.Flush()

klog.Info("Starting Hyperlight Device Plugin")

plugin, err := NewHyperlightDevicePlugin()
if err != nil {
log.Fatalf("Failed to create device plugin: %v", err)
klog.Fatalf("Failed to create device plugin: %v", err)
}

// Handle signals for graceful shutdown
Expand All @@ -350,7 +413,7 @@ func main() {
go func() {
for {
if err := plugin.Start(); err != nil {
log.Printf("Failed to start device plugin: %v", err)
klog.Errorf("Failed to start device plugin: %v", err)
time.Sleep(5 * time.Second)
continue
}
Expand All @@ -360,13 +423,13 @@ func main() {
plugin.watchKubeletRestart()

// If we get here, kubelet restarted - stop current server and re-register
log.Println("Detected kubelet restart, re-registering...")
klog.Info("Detected kubelet restart, re-registering...")
plugin.server.Stop()
time.Sleep(time.Second) // Brief pause before restart
}
}()

sig := <-sigCh
log.Printf("Received signal %v, shutting down", sig)
klog.Infof("Received signal %v, shutting down", sig)
plugin.Stop()
}