Skip to content

Commit d134551

Browse files
committed
updated vfio to represent virtual function
1 parent 55e063c commit d134551

File tree

3 files changed

+49
-58
lines changed

3 files changed

+49
-58
lines changed

internal/vfio/vfio.go

Lines changed: 26 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ func NewVFIOManager(nvlibInstance nvlib.Interface) *VFIOManager {
2828
// ParentDevice represents an NVIDIA parent PCI device.
2929
type ParentDevice struct {
3030
*nvpci.NvidiaPCIDevice
31-
VirtualFunctionPaths map[string]string
31+
VirtualFunctionPath string
3232
}
3333

3434
// Device represents an NVIDIA (vGPU) device.
@@ -46,19 +46,17 @@ func (m *VFIOManager) GetAllParentDevices() ([]*ParentDevice, error) {
4646
for _, device := range nvdevices {
4747
vfnum := 0
4848
numVF := int(device.SriovInfo.PhysicalFunction.NumVFs)
49-
virtualFunctionPaths := make(map[string]string)
5049
for vfnum < numVF {
5150
vfAddr := filepath.Join(HostPCIDevicesRoot, device.Address, "virtfn"+strconv.Itoa(vfnum), "nvidia")
5251
if _, err := os.Stat(vfAddr); err != nil {
5352
return nil, fmt.Errorf("virtual function %d at address %s does not exist", vfnum, vfAddr)
5453
}
55-
virtualFunctionPaths[strconv.Itoa(vfnum)] = vfAddr
54+
parentDevices = append(parentDevices, &ParentDevice{
55+
NvidiaPCIDevice: device,
56+
VirtualFunctionPath: vfAddr,
57+
})
5658
vfnum++
5759
}
58-
parentDevices = append(parentDevices, &ParentDevice{
59-
NvidiaPCIDevice: device,
60-
VirtualFunctionPaths: virtualFunctionPaths,
61-
})
6260
}
6361
return parentDevices, nil
6462
}
@@ -70,8 +68,7 @@ func (m *VFIOManager) GetAllDevices() ([]*Device, error) {
7068
}
7169
devices := []*Device{}
7270
for _, parentDevice := range parentDevices {
73-
for _, vfAddr := range parentDevice.VirtualFunctionPaths {
74-
vgpuTypeNumberBytes, err := os.ReadFile(filepath.Join(vfAddr, "current_vgpu_type"))
71+
vgpuTypeNumberBytes, err := os.ReadFile(filepath.Join(parentDevice.VirtualFunctionPath, "current_vgpu_type"))
7572
if err != nil {
7673
return nil, fmt.Errorf("unable to read current vGPU type: %v", err)
7774
}
@@ -81,19 +78,18 @@ func (m *VFIOManager) GetAllDevices() ([]*Device, error) {
8178
}
8279
if vgpuTypeNumber != 0 {
8380
devices = append(devices, &Device{
84-
Path: vfAddr,
81+
Path: parentDevice.VirtualFunctionPath,
8582
Parent: parentDevice,
8683
})
8784
}
88-
}
8985
}
9086
return devices, nil
9187
}
9288

9389
// GetPhysicalFunction gets the physical PCI device backing a 'parent' device.
9490
func (p *ParentDevice) GetPhysicalFunction() *nvpci.NvidiaPCIDevice {
95-
if p.SriovInfo.IsVF() {
96-
return p.SriovInfo.VirtualFunction.PhysicalFunction
91+
if p.NvidiaPCIDevice.SriovInfo.IsVF() {
92+
return p.NvidiaPCIDevice.SriovInfo.VirtualFunction.PhysicalFunction
9793
}
9894
// Either it is an SRIOV physical function or a non-SRIOV device, so return the device itself
9995
return p.NvidiaPCIDevice
@@ -149,24 +145,22 @@ func (m *VFIOManager) IsVFIOEnabled(gpu int) (bool, error) {
149145

150146
// IsVGPUTypeSupported checks if the vfioType is supported by this parent GPU
151147
func (p *ParentDevice) IsVGPUTypeAvailable(vfioType string) (bool, error) {
152-
for _, vfPath := range p.VirtualFunctionPaths {
153-
creatableTypesPath := filepath.Join(vfPath, "creatable_vgpu_types")
154-
file, err := os.Open(creatableTypesPath)
155-
if err != nil {
156-
return false, fmt.Errorf("unable to open file %s: %v", creatableTypesPath, err)
148+
creatableTypesPath := filepath.Join(p.VirtualFunctionPath, "creatable_vgpu_types")
149+
file, err := os.Open(creatableTypesPath)
150+
if err != nil {
151+
return false, fmt.Errorf("unable to open file %s: %v", creatableTypesPath, err)
152+
}
153+
defer file.Close()
154+
scanner := bufio.NewScanner(file)
155+
for scanner.Scan() {
156+
line := scanner.Text()
157+
fields := strings.Fields(line)
158+
if len(fields) < 2 {
159+
continue
157160
}
158-
defer file.Close()
159-
scanner := bufio.NewScanner(file)
160-
for scanner.Scan() {
161-
line := scanner.Text()
162-
fields := strings.Fields(line)
163-
if len(fields) < 2 {
164-
continue
165-
}
166-
name := fields[len(fields)-1]
167-
if name == vfioType {
168-
return true, nil
169-
}
161+
name := fields[len(fields)-1]
162+
if name == vfioType {
163+
return true, nil
170164
}
171165
}
172166
return false, nil
@@ -183,7 +177,7 @@ func (m *Device) Delete() error {
183177
}
184178

185179
func (p *ParentDevice) CreateVGPUDevice(vfioType string, vfnum string) error {
186-
vfPath := p.VirtualFunctionPaths[vfnum]
180+
vfPath := p.VirtualFunctionPath
187181
currentVGPUTypePath := filepath.Join(vfPath, "current_vgpu_type")
188182
number, err := p.GetIdForVGPUTypeName(filepath.Join(vfPath, "creatable_vgpu_types"), vfioType)
189183
if err != nil {
@@ -202,7 +196,7 @@ func (p *ParentDevice) GetAvailableVGPUInstances(vfioType string) (int, error) {
202196
return 0, fmt.Errorf("unable to check if vGPU type is available: %v", err)
203197
}
204198
if available {
205-
return int(p.NvidiaPCIDevice.SriovInfo.PhysicalFunction.NumVFs), nil
199+
return 1, nil
206200
}
207201
return 0, nil
208202
}

internal/vgpu-combined/vgpu-combined.go

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,13 @@ package vgpu_combined
22

33
import (
44
"fmt"
5+
"strconv"
56

67
"github.com/NVIDIA/go-nvlib/pkg/nvmdev"
78
"github.com/NVIDIA/go-nvlib/pkg/nvpci"
89
"github.com/NVIDIA/vgpu-device-manager/internal/nvlib"
910
"github.com/NVIDIA/vgpu-device-manager/internal/vfio"
11+
"github.com/google/uuid"
1012
)
1113

1214
type VGPUCombinedManager struct {
@@ -147,11 +149,11 @@ func (m *VGPUCombinedManager) GetAllDevices() ([]DeviceInterface, error) {
147149
}
148150

149151
func (m *VGPUCombinedManager) CreateVGPUDevices(device *nvpci.NvidiaPCIDevice, vgpuType string, count int) error {
150-
remainingToCreate := count
151152
parents, err := m.GetAllParentDevicesbyAddress(device.Address)
152153
if err != nil {
153154
return fmt.Errorf("error getting all parent devices by address: %v", err)
154155
}
156+
remainingToCreate := count
155157
for _, parent := range parents {
156158
if remainingToCreate == 0 {
157159
break
@@ -166,10 +168,10 @@ func (m *VGPUCombinedManager) CreateVGPUDevices(device *nvpci.NvidiaPCIDevice, v
166168

167169
numToCreate := min(remainingToCreate, available)
168170
for i := 0; i < numToCreate; i++ {
169-
if m.combined.IsVFIOMode() {
171+
if m.IsVFIOMode() {
170172
err = parent.CreateVGPUDevice(vgpuType, strconv.Itoa(i))
171173
if err != nil {
172-
return fmt.Errorf("unable to create %s vGPU device on parent device %s: %v", key, parent.GetPhysicalFunction().Address, err)
174+
return fmt.Errorf("unable to create %s vGPU device on parent device %s: %v", vgpuType, parent.GetPhysicalFunction().Address, err)
173175
}
174176
} else {
175177
err = parent.CreateVGPUDevice(vgpuType, uuid.New().String())

pkg/vgpu/config.go

Lines changed: 18 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,11 @@ package vgpu
1818

1919
import (
2020
"fmt"
21-
"strconv"
2221
"strings"
2322

2423
"github.com/NVIDIA/go-nvml/pkg/nvml"
2524
vgpu_combined "github.com/NVIDIA/vgpu-device-manager/internal/vgpu-combined"
2625
"github.com/NVIDIA/vgpu-device-manager/pkg/types"
27-
"github.com/google/uuid"
28-
"slices"
2926
)
3027

3128
const (
@@ -111,19 +108,7 @@ func (m *nvlibVGPUConfigManager) SetVGPUConfig(gpu int, config types.VGPUConfig)
111108
if ret != nvml.SUCCESS {
112109
return fmt.Errorf("failed to get supported vGPUs: %v", nvml.ErrorString(ret))
113110
}
114-
115-
for key := range config {
116-
found := false
117-
for _, vgpuTypeId := range supportedVGPUs {
118-
if vgpuTypeId.GetName() == key {
119-
found = true
120-
}
121-
}
122-
if !found {
123-
return fmt.Errorf("vGPU type %s is not supported on GPU (index=%d, address=%s)", key, gpu, device.Address)
124-
}
125-
}
126-
111+
127112
// Before deleting any existing vGPU devices, ensure all vGPU types specified in
128113
// the config are supported for the GPU we are applying the configuration to.
129114
//
@@ -137,11 +122,21 @@ func (m *nvlibVGPUConfigManager) SetVGPUConfig(gpu int, config types.VGPUConfig)
137122
sanitizedConfig := types.VGPUConfig{}
138123
for key, val := range config {
139124
strippedKey := stripVGPUConfigSuffix(key)
140-
if keyAvailable, err := parents[0].IsVGPUTypeAvailable(key); err == nil && keyAvailable {
141-
sanitizedConfig[key] = val
142-
} else if strippedKeyAvailable, err := parents[0].IsVGPUTypeAvailable(strippedKey); err == nil && strippedKeyAvailable {
143-
sanitizedConfig[strippedKey] = val
144-
} else {
125+
found := false
126+
for _, vgpuTypeId := range supportedVGPUs {
127+
name, ret := vgpuTypeId.GetName()
128+
if ret == nvml.SUCCESS && name == key {
129+
found = true
130+
sanitizedConfig[key] = val
131+
break
132+
}
133+
if ret == nvml.SUCCESS && name == strippedKey {
134+
found = true
135+
sanitizedConfig[strippedKey] = val
136+
break
137+
}
138+
}
139+
if !found {
145140
return fmt.Errorf("vGPU type %s is not supported on GPU (index=%d, address=%s)", key, gpu, device.Address)
146141
}
147142
}
@@ -152,14 +147,14 @@ func (m *nvlibVGPUConfigManager) SetVGPUConfig(gpu int, config types.VGPUConfig)
152147
}
153148

154149
for key, val := range sanitizedConfig {
155-
156150
creatableVGPUs, ret := nvmlDevice.GetCreatableVgpus()
157151
if ret != nvml.SUCCESS {
158152
return fmt.Errorf("failed to get creatable vGPUs: %v", nvml.ErrorString(ret))
159153
}
160154
found := false
161155
for _, vgpuTypeId := range creatableVGPUs {
162-
if vgpuTypeId.GetName() == key {
156+
name, ret := vgpuTypeId.GetName()
157+
if ret == nvml.SUCCESS && name == key {
163158
found = true
164159
break
165160
}

0 commit comments

Comments
 (0)