Skip to content
Open
30 changes: 25 additions & 5 deletions cmd/nvidia-vgpu-dm/apply/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,22 @@ import (
func VGPUConfig(c *Context) error {
return assert.WalkSelectedVGPUConfigForEachGPU(c.VGPUConfig, func(vc *v1.VGPUConfigSpec, i int, d types.DeviceID) error {
configManager := vgpu.NewNvlibVGPUConfigManager()
current, err := configManager.GetVGPUConfig(i)
IsVFIOEnabled, err := configManager.IsVFIOEnabled()
if err != nil {
return fmt.Errorf("error getting vGPU config: %v", err)
return fmt.Errorf("error checking if Ubuntu 24.04: %v", err)
}

var current types.VGPUConfig
if IsVFIOEnabled {
current, err = configManager.GetVGPUConfigforVFIO(i)
if err != nil {
return fmt.Errorf("error getting VGPU config for VFIO: %v", err)
}
} else {
current, err = configManager.GetVGPUConfig(i)
if err != nil {
return fmt.Errorf("error getting vGPU config: %v", err)
}
}

if current.Equals(vc.VGPUDevices) {
Expand All @@ -40,9 +53,16 @@ func VGPUConfig(c *Context) error {
}

log.Debugf(" Updating vGPU config: %v", vc.VGPUDevices)
err = configManager.SetVGPUConfig(i, vc.VGPUDevices)
if err != nil {
return fmt.Errorf("error setting VGPU config: %v", err)
if IsVFIOEnabled {
err = configManager.SetVGPUConfigforVFIO(i, vc.VGPUDevices)
if err != nil {
return fmt.Errorf("error setting VGPU config for VFIO: %v", err)
}
} else {
err = configManager.SetVGPUConfig(i, vc.VGPUDevices)
if err != nil {
return fmt.Errorf("error setting vGPU config: %v", err)
}
}

return nil
Expand Down
17 changes: 15 additions & 2 deletions cmd/nvidia-vgpu-dm/assert/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,22 @@ func VGPUConfig(c *Context) error {
matched := make([]bool, len(gpus))
err = WalkSelectedVGPUConfigForEachGPU(c.VGPUConfig, func(vc *v1.VGPUConfigSpec, i int, d types.DeviceID) error {
configManager := vgpu.NewNvlibVGPUConfigManager()
current, err := configManager.GetVGPUConfig(i)
IsVFIOEnabled, err := configManager.IsVFIOEnabled()
if err != nil {
return fmt.Errorf("error getting vGPU config: %v", err)
return fmt.Errorf("error checking if VFIO is enabled: %v", err)
}

var current types.VGPUConfig
if IsVFIOEnabled {
current, err = configManager.GetVGPUConfigforVFIO(i)
if err != nil {
return fmt.Errorf("error getting VGPU config for VFIO: %v", err)
}
} else {
current, err = configManager.GetVGPUConfig(i)
if err != nil {
return fmt.Errorf("error getting vGPU config: %v", err)
}
}

log.Debugf(" Asserting vGPU config: %v", vc.VGPUDevices)
Expand Down
165 changes: 165 additions & 0 deletions pkg/vgpu/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ package vgpu
import (
"fmt"
"strings"
"os"
"strconv"
"bufio"
"os/exec"

"github.com/NVIDIA/go-nvlib/pkg/nvmdev"
"github.com/google/uuid"
Expand All @@ -27,11 +31,18 @@ import (
"github.com/NVIDIA/vgpu-device-manager/pkg/types"
)

const (
// HostPCIDevicesRoot is the directory under which per-device PCI entries are
// looked up. Paths of the form <root>/<PCI address>/virtfn<N>/nvidia are built
// from it when reading or writing the vGPU type of a virtual function.
HostPCIDevicesRoot = "/host/sys/bus/pci/devices"
)

// Manager represents a set of functions for managing vGPU configurations on a node
type Manager interface {
// GetVGPUConfig gets the 'VGPUConfig' currently applied to the GPU at index 'gpu'.
GetVGPUConfig(gpu int) (types.VGPUConfig, error)
// SetVGPUConfig takes the index of a GPU and the desired 'VGPUConfig'.
// NOTE(review): implementation is outside this view; presumably it applies
// 'config' to that GPU - confirm.
SetVGPUConfig(gpu int, config types.VGPUConfig) error
// ClearVGPUConfig takes the index of a GPU.
// NOTE(review): implementation is outside this view; presumably it removes
// the vGPU devices configured on that GPU - confirm.
ClearVGPUConfig(gpu int) error
// IsVFIOEnabled reports whether the host's OS release matches one of the
// distribution/version pairs for which the *forVFIO methods should be used.
IsVFIOEnabled() (bool, error)
// GetVGPUConfigforVFIO returns, per vGPU type name, how many virtual functions
// of the GPU at index 'gpu' currently have that type set.
GetVGPUConfigforVFIO(gpu int) (types.VGPUConfig, error)
// SetVGPUConfigforVFIO writes the vGPU types requested in 'config' to the
// virtual functions of the GPU at index 'gpu'.
SetVGPUConfigforVFIO(gpu int, config types.VGPUConfig) error
}

type nvlibVGPUConfigManager struct {
Expand All @@ -45,6 +56,160 @@ func NewNvlibVGPUConfigManager() Manager {
return &nvlibVGPUConfigManager{nvlib.New()}
}

// GetVGPUConfigforVFIO returns the vGPU configuration currently set on the
// virtual functions of the GPU at index 'gpu'.
//
// For every virtual function index below the physical function's TotalVFs, the
// VF's "nvidia" directory under HostPCIDevicesRoot is inspected. VFs whose
// directory cannot be stat'ed are skipped, as are VFs whose current_vgpu_type
// is 0 (presumably meaning no vGPU type is set - confirm). Any other type
// number is translated to a name via the VF's creatable_vgpu_types file and
// counted in the result.
//
// A GPU without a physical function entry yields an empty, non-nil config.
func (m *nvlibVGPUConfigManager) GetVGPUConfigforVFIO(gpu int) (types.VGPUConfig, error) {
	device, err := m.nvlib.Nvpci.GetGPUByIndex(gpu)
	if err != nil {
		return nil, fmt.Errorf("unable to get GPU by index %d: %v", gpu, err)
	}

	result := types.VGPUConfig{}
	pf := device.SriovInfo.PhysicalFunction
	if pf == nil {
		return result, nil
	}

	deviceRoot := HostPCIDevicesRoot + "/" + device.Address
	for vf, total := 0, int(pf.TotalVFs); vf < total; vf++ {
		vfDir := deviceRoot + "/virtfn" + strconv.Itoa(vf) + "/nvidia"
		if _, statErr := os.Stat(vfDir); statErr != nil {
			// No usable "nvidia" directory for this VF; move on to the next one.
			continue
		}

		raw, err := os.ReadFile(vfDir + "/current_vgpu_type")
		if err != nil {
			return nil, fmt.Errorf("unable to read current vGPU type: %v", err)
		}
		typeNumber, err := strconv.Atoi(strings.TrimSpace(string(raw)))
		if err != nil {
			return nil, fmt.Errorf("unable to convert current vGPU type to int: %v", err)
		}
		if typeNumber == 0 {
			continue
		}

		typeName, err := m.getVGPUTypeNameforVFIO(vfDir+"/creatable_vgpu_types", typeNumber)
		if err != nil {
			return nil, fmt.Errorf("unable to get vGPU type name: %v", err)
		}
		result[typeName]++
	}
	return result, nil
}

// SetVGPUConfigforVFIO applies 'config' to the GPU at index 'gpu' by writing
// vGPU type numbers to the GPU's virtual functions.
//
// The GPU is resolved by index and checked against the list of NVIDIA GPUs.
// The driver's sriov-manage tool is then run (chrooted into /host) with "-e"
// and the GPU's PCI address; per NVIDIA's documentation this enables the GPU's
// virtual functions. Finally, for each vGPU type in 'config', the requested
// number of virtual functions get that type written to current_vgpu_type.
//
// Virtual functions are consumed sequentially across ALL types in 'config':
// the VF index is not reset between types, so a config with several types
// never overwrites a VF that was assigned earlier in the same call. Map
// iteration order is unspecified, so which VF receives which type is not fixed.
//
// An error is returned if the GPU cannot be resolved, sriov-manage fails, a
// type name is not listed in a VF's creatable_vgpu_types, or a write fails.
func (m *nvlibVGPUConfigManager) SetVGPUConfigforVFIO(gpu int, config types.VGPUConfig) error {
	nvdevice, err := m.nvlib.Nvpci.GetGPUByIndex(gpu)
	if err != nil {
		return fmt.Errorf("unable to get GPU by index %d: %v", gpu, err)
	}

	GPUDevices, err := m.nvlib.Nvpci.GetGPUs()
	if err != nil {
		return fmt.Errorf("unable to get all NVIDIA GPU devices: %v", err)
	}

	deviceFound := false
	for _, device := range GPUDevices {
		if device.Address == nvdevice.Address {
			deviceFound = true
			break
		}
	}
	if !deviceFound {
		return fmt.Errorf("GPU at index %d not found in available NVIDIA devices", gpu)
	}

	cmd := exec.Command("chroot", "/host", "/run/nvidia/driver/usr/lib/nvidia/sriov-manage", "-e", nvdevice.Address)
	output, err := cmd.CombinedOutput()
	if err != nil {
		return fmt.Errorf("unable to execute sriov-manage: %v, output: %s", err, string(output))
	}

	// vfIndex is shared by every vGPU type in the config. Restarting it at 0 for
	// each type would overwrite the VFs configured for the previous type.
	vfIndex := 0
	for vgpuType, count := range config {
		for remaining := count; remaining > 0; remaining-- {
			vfDir := HostPCIDevicesRoot + "/" + nvdevice.Address + "/virtfn" + strconv.Itoa(vfIndex) + "/nvidia"
			// The type number is looked up per VF, from that VF's own listing.
			number, err := m.getVGPUTypeNumberforVFIO(vfDir+"/creatable_vgpu_types", vgpuType)
			if err != nil {
				return fmt.Errorf("unable to get vGPU type number: %v", err)
			}
			err = os.WriteFile(vfDir+"/current_vgpu_type", []byte(strconv.Itoa(number)), 0644)
			if err != nil {
				return fmt.Errorf("unable to write current vGPU type: %v", err)
			}
			vfIndex++
		}
	}
	return nil
}

// getVGPUTypeNameforVFIO looks up the name of the vGPU type whose numeric ID is
// 'vgpuTypeNumber' in the listing stored at 'filePath'.
//
// Each line is split on whitespace: the first field is parsed as the numeric
// type ID and the last field is taken as the type name. Lines with fewer than
// two fields, or whose first field is not an integer (e.g. a header line), are
// ignored.
//
// An error is returned if the file cannot be opened, if reading it fails
// part-way through, or if no line carries the requested ID.
func (m *nvlibVGPUConfigManager) getVGPUTypeNameforVFIO(filePath string, vgpuTypeNumber int) (string, error) {
	file, err := os.Open(filePath)
	if err != nil {
		return "", fmt.Errorf("unable to open file %s: %v", filePath, err)
	}
	defer file.Close()

	scanner := bufio.NewScanner(file)
	for scanner.Scan() {
		fields := strings.Fields(scanner.Text())
		if len(fields) < 2 {
			continue
		}
		number, err := strconv.Atoi(fields[0])
		if err != nil || number != vgpuTypeNumber {
			continue
		}
		return fields[len(fields)-1], nil
	}
	// Scan also stops on a read failure or an over-long line; report that
	// instead of claiming the type is missing from the file.
	if err := scanner.Err(); err != nil {
		return "", fmt.Errorf("unable to read file %s: %v", filePath, err)
	}
	return "", fmt.Errorf("vGPU type %d not found in file %s", vgpuTypeNumber, filePath)
}

// getVGPUTypeNumberforVFIO looks up the numeric ID of the vGPU type named
// 'vgpuTypeName' in the listing stored at 'filePath'.
//
// Each line is split on whitespace: the first field is parsed as the numeric
// type ID and the last field is compared against 'vgpuTypeName'. Lines with
// fewer than two fields, or whose first field is not an integer (e.g. a header
// line), are ignored.
//
// An error is returned if the file cannot be opened, if reading it fails
// part-way through, or if no line carries the requested name.
func (m *nvlibVGPUConfigManager) getVGPUTypeNumberforVFIO(filePath string, vgpuTypeName string) (int, error) {
	file, err := os.Open(filePath)
	if err != nil {
		return 0, fmt.Errorf("unable to open file %s: %v", filePath, err)
	}
	defer file.Close()

	scanner := bufio.NewScanner(file)
	for scanner.Scan() {
		fields := strings.Fields(scanner.Text())
		if len(fields) < 2 || fields[len(fields)-1] != vgpuTypeName {
			continue
		}
		number, err := strconv.Atoi(fields[0])
		if err != nil {
			continue
		}
		return number, nil
	}
	// Scan also stops on a read failure or an over-long line; report that
	// instead of claiming the type is missing from the file.
	if err := scanner.Err(); err != nil {
		return 0, fmt.Errorf("unable to read file %s: %v", filePath, err)
	}
	return 0, fmt.Errorf("vGPU type %s not found in file %s", vgpuTypeName, filePath)
}

// IsVFIOEnabled reports whether the host's OS release is one of the
// distribution/version pairs for which the *forVFIO code paths are used
// (currently Ubuntu 24.04 and RHEL 10).
//
// The decision is made from the host's /etc/os-release, read via the /host
// mount. An error is returned only if that file cannot be read.
func (m *nvlibVGPUConfigManager) IsVFIOEnabled() (bool, error) {
	// Read from the host's /etc/os-release (mounted at /host in the container)
	data, err := os.ReadFile("/host/etc/os-release")
	if err != nil {
		return false, fmt.Errorf("unable to read host OS release info: %v", err)
	}
	return osReleaseSupportsVFIO(string(data)), nil
}

// osReleaseSupportsVFIO decides, from the text of an os-release file, whether
// the described OS is one of the supported distribution/version pairs.
//
// Only the ID, ID_LIKE and VERSION_ID keys are consulted. Matching on these
// fields, rather than on substrings of the whole file, avoids false positives
// such as RHEL 8.10, whose os-release contains both "rhel" and "10". A
// distribution matches through ID or any ID_LIKE entry; a version matches when
// VERSION_ID equals the listed version or extends it with a ".minor" suffix
// (e.g. "10.0" for "10").
func osReleaseSupportsVFIO(content string) bool {
	vfioDistributions := map[string]string{
		"ubuntu": "24.04",
		"rhel":   "10",
	}

	var id, idLike, versionID string
	for _, line := range strings.Split(content, "\n") {
		kv := strings.SplitN(strings.TrimSpace(line), "=", 2)
		if len(kv) != 2 {
			continue
		}
		// Values may be wrapped in single or double quotes.
		value := strings.Trim(strings.TrimSpace(kv[1]), `"'`)
		switch strings.TrimSpace(kv[0]) {
		case "ID":
			id = value
		case "ID_LIKE":
			idLike = value
		case "VERSION_ID":
			versionID = value
		}
	}

	candidates := append([]string{id}, strings.Fields(idLike)...)
	for _, name := range candidates {
		version, ok := vfioDistributions[strings.ToLower(name)]
		if !ok {
			continue
		}
		if versionID == version || strings.HasPrefix(versionID, version+".") {
			return true
		}
	}
	return false
}

// GetVGPUConfig gets the 'VGPUConfig' currently applied to a GPU at a particular index
func (m *nvlibVGPUConfigManager) GetVGPUConfig(gpu int) (types.VGPUConfig, error) {
device, err := m.nvlib.Nvpci.GetGPUByIndex(gpu)
Expand Down