diff --git a/device-plugin/main.go b/device-plugin/main.go index 8b0d473..9c0a5c3 100644 --- a/device-plugin/main.go +++ b/device-plugin/main.go @@ -237,6 +237,9 @@ func (p *HyperlightDevicePlugin) GetPreferredAllocation(ctx context.Context, req } func (p *HyperlightDevicePlugin) Start() error { + // Reset stop channel for restart scenarios + p.stopCh = make(chan struct{}) + // Remove old socket if err := os.Remove(serverSock); err != nil && !os.IsNotExist(err) { log.Printf("Warning: failed to remove old socket: %v", err) @@ -307,6 +310,29 @@ func (p *HyperlightDevicePlugin) Stop() { log.Println("Device plugin stopped") } +// watchKubeletRestart monitors for kubelet restarts by watching the plugin socket. +// When kubelet restarts, it deletes all sockets in /var/lib/kubelet/device-plugins/. +// This function blocks until the socket is deleted, signaling a kubelet restart. +func (p *HyperlightDevicePlugin) watchKubeletRestart() { + log.Println("Watching for kubelet restart (socket deletion)...") + + ticker := time.NewTicker(time.Second) + defer ticker.Stop() + + for { + select { + case <-p.stopCh: + return + case <-ticker.C: + // Check if our socket still exists + if _, err := os.Stat(serverSock); os.IsNotExist(err) { + log.Println("Plugin socket deleted - kubelet may have restarted") + return + } + } + } +} + func main() { log.SetFlags(log.LstdFlags | log.Lshortfile) log.Println("Starting Hyperlight Device Plugin") @@ -316,14 +342,30 @@ func main() { log.Fatalf("Failed to create device plugin: %v", err) } - if err := plugin.Start(); err != nil { - log.Fatalf("Failed to start device plugin: %v", err) - } - // Handle signals for graceful shutdown sigCh := make(chan os.Signal, 1) signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) + // Start plugin with restart handling + go func() { + for { + if err := plugin.Start(); err != nil { + log.Printf("Failed to start device plugin: %v", err) + time.Sleep(5 * time.Second) + continue + } + + // Watch for kubelet restart (socket deletion) + // When kubelet restarts, it deletes all sockets in /var/lib/kubelet/device-plugins/ + plugin.watchKubeletRestart() + + // If we get here, kubelet restarted - stop current server and re-register + log.Println("Detected kubelet restart, re-registering...") + plugin.server.Stop() + time.Sleep(time.Second) // Brief pause before restart + } + }() + sig := <-sigCh log.Printf("Received signal %v, shutting down", sig) plugin.Stop()