Skip to content

Commit 9ad1573

Browse files
feat: detect config changes (version, kernel, module params, RDMA) to trigger driver reinstall
Signed-off-by: Karthik Vetrivel <[email protected]>
1 parent a540c4f commit 9ad1573

File tree

1 file changed

+133
-59
lines changed

1 file changed

+133
-59
lines changed

cmd/driver-manager/main.go

Lines changed: 133 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
package main
2020

2121
import (
22+
"bytes"
2223
"context"
2324
"errors"
2425
"fmt"
@@ -40,15 +41,30 @@ import (
4041
)
4142

4243
const (
43-
driverRoot = "/run/nvidia/driver"
44-
driverPIDFile = "/run/nvidia/nvidia-driver.pid"
45-
operatorNamespace = "gpu-operator"
46-
pausedStr = "paused-for-driver-upgrade"
47-
defaultDrainTimeout = time.Second * 0
48-
defaultGracePeriod = 5 * time.Minute
44+
driverRoot = "/run/nvidia/driver"
45+
driverPIDFile = "/run/nvidia/nvidia-driver.pid"
46+
driverConfigStateFile = "/run/nvidia/driver-config.state"
47+
operatorNamespace = "gpu-operator"
48+
pausedStr = "paused-for-driver-upgrade"
49+
defaultDrainTimeout = time.Second * 0
50+
defaultGracePeriod = 5 * time.Minute
4951

5052
nvidiaDomainPrefix = "nvidia.com"
5153

54+
nvidiaModuleConfigFile = "/drivers/nvidia.conf"
55+
nvidiaUVMModuleConfigFile = "/drivers/nvidia-uvm.conf"
56+
nvidiaModsetModuleConfigFile = "/drivers/nvidia-modeset.conf"
57+
nvidiaPeermemModuleConfigFile = "/drivers/nvidia-peermem.conf"
58+
)
59+
60+
var (
61+
driverConfigFiles = []string{
62+
nvidiaModuleConfigFile,
63+
nvidiaUVMModuleConfigFile,
64+
nvidiaModsetModuleConfigFile,
65+
nvidiaPeermemModuleConfigFile,
66+
}
67+
5268
nvidiaDriverDeployLabel = nvidiaDomainPrefix + "/" + "gpu.deploy.driver"
5369
nvidiaOperatorValidatorDeployLabel = nvidiaDomainPrefix + "/" + "gpu.deploy.operator-validator"
5470
nvidiaContainerToolkitDeployLabel = nvidiaDomainPrefix + "/" + "gpu.deploy.container-toolkit"
@@ -304,8 +320,20 @@ func (dm *DriverManager) uninstallDriver() error {
304320
return fmt.Errorf("failed to evict GPU operator components: %w", err)
305321
}
306322

307-
if skip, reason := dm.shouldSkipUninstall(); skip {
308-
dm.log.Infof("Skipping driver uninstall: %s", reason)
323+
if dm.shouldSkipUninstall() {
324+
dm.log.Info("Fast path activated: desired driver version and configuration already present")
325+
326+
// Clean up stale artifacts from previous container before rescheduling operands
327+
dm.log.Info("Cleaning up stale mounts and state files...")
328+
329+
// Unmount stale rootfs from previous container
330+
if err := dm.unmountRootfs(); err != nil {
331+
return fmt.Errorf("failed to unmount stale rootfs: %w", err)
332+
}
333+
334+
// Remove stale PID file from previous container
335+
dm.removePIDFile()
336+
309337
if err := dm.rescheduleGPUOperatorComponents(); err != nil {
310338
dm.log.Warnf("Failed to reschedule GPU operator components: %v", err)
311339
}
@@ -653,68 +681,113 @@ func (dm *DriverManager) isDriverLoaded() bool {
653681
return err == nil
654682
}
655683

656-
func (dm *DriverManager) shouldSkipUninstall() (bool, string) {
657-
if dm.config.forceReinstall {
658-
dm.log.Info("Force reinstall is enabled, proceeding with driver uninstall")
659-
return false, ""
684+
// getValueWithOverride extracts a value from config by key, but returns override if non-empty
685+
func getValueWithOverride(config, key, override string) string {
686+
if override != "" {
687+
return override
660688
}
661-
662-
if !dm.isDriverLoaded() {
663-
return false, ""
689+
for _, line := range strings.Split(config, "\n") {
690+
if strings.HasPrefix(line, key+"=") {
691+
return strings.TrimPrefix(line, key+"=")
692+
}
664693
}
694+
return ""
695+
}
665696

666-
if dm.config.driverVersion == "" {
667-
return false, "Driver version environment variable is not set"
697+
// getKernelVersion returns the current kernel version
698+
func getKernelVersion() string {
699+
var utsname unix.Utsname
700+
if err := unix.Uname(&utsname); err != nil {
701+
return ""
668702
}
669703

670-
version, err := dm.detectCurrentDriverVersion()
671-
if err != nil {
672-
dm.log.Warnf("Unable to determine installed driver version: %v", err)
673-
// If driver is loaded but we can't detect version, proceed with reinstall to ensure correct version
674-
dm.log.Info("Cannot verify driver version, proceeding with reinstall to ensure correct version is installed")
675-
return false, ""
676-
}
704+
release := utsname.Release[:]
705+
nullIdx := bytes.IndexByte(release, 0)
706+
return string(release[:nullIdx])
707+
}
677708

678-
if version != dm.config.driverVersion {
679-
dm.log.Infof("Installed driver version %s does not match desired %s, proceeding with uninstall", version, dm.config.driverVersion)
680-
return false, ""
709+
// buildCurrentConfig constructs the current driver configuration string
710+
func (dm *DriverManager) buildCurrentConfig(storedConfig string) string {
711+
driverVersion := getValueWithOverride(storedConfig, "DRIVER_VERSION", dm.config.driverVersion)
712+
kernelVersion := getValueWithOverride(storedConfig, "KERNEL_VERSION", getKernelVersion())
713+
kernelModuleType := getValueWithOverride(storedConfig, "KERNEL_MODULE_TYPE", os.Getenv("KERNEL_MODULE_TYPE"))
714+
driverTypeEnv := os.Getenv("DRIVER_TYPE")
715+
if driverTypeEnv == "" {
716+
driverTypeEnv = "passthrough"
717+
}
718+
driverType := getValueWithOverride(storedConfig, "DRIVER_TYPE", driverTypeEnv)
719+
720+
// Read module parameters from conf files
721+
nvidiaParams := readModuleParams(nvidiaModuleConfigFile)
722+
nvidiaUVMParams := readModuleParams(nvidiaUVMModuleConfigFile)
723+
nvidiaModeset := readModuleParams(nvidiaModsetModuleConfigFile)
724+
nvidiaPeermem := readModuleParams(nvidiaPeermemModuleConfigFile)
725+
726+
var config strings.Builder
727+
config.WriteString(fmt.Sprintf("DRIVER_VERSION=%s\n", driverVersion))
728+
config.WriteString(fmt.Sprintf("DRIVER_TYPE=%s\n", driverType))
729+
config.WriteString(fmt.Sprintf("KERNEL_VERSION=%s\n", kernelVersion))
730+
config.WriteString(fmt.Sprintf("GPU_DIRECT_RDMA_ENABLED=%v\n", dm.config.gpuDirectRDMAEnabled))
731+
config.WriteString(fmt.Sprintf("USE_HOST_MOFED=%v\n", dm.config.useHostMofed))
732+
config.WriteString(fmt.Sprintf("KERNEL_MODULE_TYPE=%s\n", kernelModuleType))
733+
config.WriteString(fmt.Sprintf("NVIDIA_MODULE_PARAMS=%s\n", nvidiaParams))
734+
config.WriteString(fmt.Sprintf("NVIDIA_UVM_MODULE_PARAMS=%s\n", nvidiaUVMParams))
735+
config.WriteString(fmt.Sprintf("NVIDIA_MODESET_MODULE_PARAMS=%s\n", nvidiaModeset))
736+
config.WriteString(fmt.Sprintf("NVIDIA_PEERMEM_MODULE_PARAMS=%s\n", nvidiaPeermem))
737+
738+
// Append config file contents directly
739+
for _, file := range driverConfigFiles {
740+
if data, err := os.ReadFile(file); err == nil && len(data) > 0 {
741+
config.Write(data)
742+
}
681743
}
682744

683-
dm.log.Infof("Installed driver version %s matches desired version, skipping uninstall", version)
684-
return true, "desired version already present"
745+
return config.String()
685746
}
686747

687-
func (dm *DriverManager) detectCurrentDriverVersion() (string, error) {
688-
baseCtx := dm.ctx
689-
if baseCtx == nil {
690-
baseCtx = context.Background()
748+
// readModuleParams reads a module parameter config file and returns its contents as a single-line space-separated string
749+
func readModuleParams(filepath string) string {
750+
data, err := os.ReadFile(filepath)
751+
if err != nil {
752+
return ""
691753
}
754+
// Convert newlines to spaces to match bash implementation
755+
return strings.ReplaceAll(strings.TrimSpace(string(data)), "\n", " ")
756+
}
692757

693-
ctx, cancel := context.WithTimeout(baseCtx, 10*time.Second)
694-
defer cancel()
695-
696-
// Try chroot to /run/nvidia/driver for containerized driver
697-
cmd := exec.CommandContext(ctx, "chroot", "/run/nvidia/driver", "modinfo", "-F", "version", "nvidia")
698-
cmd.Env = append(os.Environ(), "LC_ALL=C")
699-
cmdOutput, chrootErr := cmd.Output()
700-
if chrootErr == nil {
701-
version := strings.TrimSpace(string(cmdOutput))
702-
if version != "" {
703-
dm.log.Infof("Driver version detected via chroot: %s", version)
704-
return version, nil
758+
// driverModuleBuildNeeded checks if driver modules need to be rebuilt
759+
func (dm *DriverManager) driverModuleBuildNeeded() bool {
760+
storedData, err := os.ReadFile(driverConfigStateFile)
761+
if err != nil {
762+
if os.IsNotExist(err) {
763+
dm.log.Info("No previous driver configuration found")
764+
return true
705765
}
766+
dm.log.Warnf("Failed to read driver config state file: %v", err)
767+
return true
706768
}
707769

708-
// Second try to read from /sys/module/nvidia/version if available
709-
if versionData, err := os.ReadFile("/sys/module/nvidia/version"); err == nil {
710-
version := strings.TrimSpace(string(versionData))
711-
if version != "" {
712-
dm.log.Infof("Driver version detected from /sys/module/nvidia/version: %s", version)
713-
return version, nil
714-
}
770+
storedConfig := string(storedData)
771+
currentConfig := dm.buildCurrentConfig(storedConfig)
772+
773+
return currentConfig != storedConfig
774+
}
775+
776+
func (dm *DriverManager) shouldSkipUninstall() bool {
777+
if dm.config.forceReinstall {
778+
dm.log.Info("Force reinstall is enabled, proceeding with driver uninstall")
779+
return false
715780
}
716781

717-
return "", fmt.Errorf("all version detection methods failed: chroot: %v", chrootErr)
782+
// Only skip uninstall if driver IS loaded AND config matches (fast path optimization)
783+
if dm.isDriverLoaded() && !dm.driverModuleBuildNeeded() {
784+
dm.log.Info("Driver is loaded with matching config, enabling fast path")
785+
return true
786+
}
787+
788+
// Driver not loaded or config changed - proceed with cleanup
789+
dm.log.Info("Proceeding with cleanup operations")
790+
return false
718791
}
719792

720793
func (dm *DriverManager) isNouveauLoaded() bool {
@@ -727,6 +800,12 @@ func (dm *DriverManager) unloadNouveau() error {
727800
return unix.DeleteModule("nouveau", 0)
728801
}
729802

803+
func (dm *DriverManager) removePIDFile() {
804+
if err := os.Remove(driverPIDFile); err != nil && !os.IsNotExist(err) {
805+
dm.log.Warnf("Failed to remove PID file %s: %v", driverPIDFile, err)
806+
}
807+
}
808+
730809
func (dm *DriverManager) cleanupDriver() error {
731810
dm.log.Info("Cleaning up NVIDIA driver")
732811

@@ -740,12 +819,7 @@ func (dm *DriverManager) cleanupDriver() error {
740819
return fmt.Errorf("failed to unmount rootfs: %w", err)
741820
}
742821

743-
// Remove PID file
744-
if _, err := os.Stat(driverPIDFile); err == nil {
745-
if err := os.Remove(driverPIDFile); err != nil {
746-
dm.log.Warnf("Failed to remove PID file %s: %v", driverPIDFile, err)
747-
}
748-
}
822+
dm.removePIDFile()
749823

750824
return nil
751825
}

0 commit comments

Comments
 (0)