Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 91 additions & 12 deletions cmd/driver-manager/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,13 @@ import (
)

const (
driverRoot = "/run/nvidia/driver"
driverPIDFile = "/run/nvidia/nvidia-driver.pid"
operatorNamespace = "gpu-operator"
pausedStr = "paused-for-driver-upgrade"
defaultDrainTimeout = time.Second * 0
defaultGracePeriod = 5 * time.Minute
driverRoot = "/run/nvidia/driver"
driverPIDFile = "/run/nvidia/nvidia-driver.pid"
driverConfigStateFile = "/run/nvidia/nvidia-driver.state"
operatorNamespace = "gpu-operator"
pausedStr = "paused-for-driver-upgrade"
defaultDrainTimeout = time.Second * 0
defaultGracePeriod = 5 * time.Minute

nvidiaDomainPrefix = "nvidia.com"

Expand Down Expand Up @@ -77,6 +78,7 @@ type config struct {
gpuDirectRDMAEnabled bool
useHostMofed bool
kubeconfig string
forceReinstall bool
}

// ComponentState tracks the deployment state of GPU operator components
Expand Down Expand Up @@ -208,6 +210,13 @@ func main() {
EnvVars: []string{"KUBECONFIG"},
Value: "",
},
&cli.BoolFlag{
Name: "force-reinstall",
Usage: "Force driver reinstall regardless of current state",
Destination: &cfg.forceReinstall,
EnvVars: []string{"FORCE_REINSTALL"},
Value: false,
},
}

app.Commands = []*cli.Command{
Expand Down Expand Up @@ -288,6 +297,26 @@ func (dm *DriverManager) uninstallDriver() error {
return fmt.Errorf("failed to evict GPU operator components: %w", err)
}

if dm.shouldSkipUninstall() {
dm.log.Info("The NVIDIA driver is already loaded with the desired version and configuration, skipping the uninstallation of the driver in an attempt to not disrupt running workloads")

// Clean up stale artifacts from previous container before rescheduling operands
dm.log.Info("Cleaning up stale mounts and state files...")

// Unmount stale rootfs from previous container
if err := dm.unmountRootfs(); err != nil {
return fmt.Errorf("failed to unmount stale rootfs: %w", err)
}

// Remove stale PID file from previous container
dm.removePIDFile()

if err := dm.rescheduleGPUOperatorComponents(); err != nil {
dm.log.Warnf("Failed to reschedule GPU operator components: %v", err)
}
return nil
}

drainOpts := kube.DrainOptions{
Force: dm.config.drainUseForce,
DeleteEmptyDirData: dm.config.drainDeleteEmptyDirData,
Expand Down Expand Up @@ -629,6 +658,55 @@ func (dm *DriverManager) isDriverLoaded() bool {
return err == nil
}

// readStoredDigest reads the driver configuration digest from the state file
func readStoredDigest() (string, error) {
data, err := os.ReadFile(driverConfigStateFile)
if err != nil {
return "", err
}
return strings.TrimSpace(string(data)), nil
}

// shouldUpdateDriverConfig checks if the driver configuration needs to be updated
func (dm *DriverManager) shouldUpdateDriverConfig() bool {
if !dm.isDriverLoaded() {
return true
}

dm.log.Info("Checking if the currently loaded NVIDIA driver version and configuration matches the desired state...")

currentDigest := os.Getenv("DRIVER_CONFIG_DIGEST")
if currentDigest == "" {
dm.log.Warn("DRIVER_CONFIG_DIGEST env var not set, assuming config changed")
return true
}

storedDigest, err := readStoredDigest()
if err != nil {
if os.IsNotExist(err) {
dm.log.Info("No previous driver configuration found")
} else {
dm.log.Warnf("Failed to read driver config state file: %v", err)
}
return true
}

return currentDigest != storedDigest
}

func (dm *DriverManager) shouldSkipUninstall() bool {
if dm.config.forceReinstall {
dm.log.Info("Force reinstall is enabled, proceeding with driver uninstall")
return false
}

if !dm.shouldUpdateDriverConfig() {
return true
}

return false
}

func (dm *DriverManager) isNouveauLoaded() bool {
_, err := os.Stat("/sys/module/nouveau/refcnt")
return err == nil
Expand All @@ -639,6 +717,12 @@ func (dm *DriverManager) unloadNouveau() error {
return unix.DeleteModule("nouveau", 0)
}

func (dm *DriverManager) removePIDFile() {
if err := os.Remove(driverPIDFile); err != nil && !os.IsNotExist(err) {
dm.log.Warnf("Failed to remove PID file %s: %v", driverPIDFile, err)
}
}

func (dm *DriverManager) cleanupDriver() error {
dm.log.Info("Cleaning up NVIDIA driver")

Expand All @@ -652,12 +736,7 @@ func (dm *DriverManager) cleanupDriver() error {
return fmt.Errorf("failed to unmount rootfs: %w", err)
}

// Remove PID file
if _, err := os.Stat(driverPIDFile); err == nil {
if err := os.Remove(driverPIDFile); err != nil {
dm.log.Warnf("Failed to remove PID file %s: %v", driverPIDFile, err)
}
}
dm.removePIDFile()

return nil
}
Expand Down