@@ -40,9 +40,6 @@ type Manager interface {
4040 GetVGPUConfig (gpu int ) (types.VGPUConfig , error )
4141 SetVGPUConfig (gpu int , config types.VGPUConfig ) error
4242 ClearVGPUConfig (gpu int ) error
43- IsVFIOEnabled () (bool , error )
44- GetVGPUConfigforVFIO (gpu int ) (types.VGPUConfig , error )
45- SetVGPUConfigforVFIO (gpu int , config types.VGPUConfig ) error
4643}
4744
4845type nvlibVGPUConfigManager struct {
@@ -56,36 +53,57 @@ func NewNvlibVGPUConfigManager() Manager {
5653 return & nvlibVGPUConfigManager {nvlib .New ()}
5754}
5855
56+ func (m * nvlibVGPUConfigManager ) GetVGPUConfig (gpu int ) (types.VGPUConfig , error ) {
57+ IsVFIOEnabled , err := m .IsVFIOEnabled ()
58+ if err != nil {
59+ return nil , fmt .Errorf ("error checking if VFIO is enabled: %v" , err )
60+ }
61+ if IsVFIOEnabled {
62+ return m .GetVGPUConfigforVFIO (gpu )
63+ }
64+ return m .GetVGPUConfigforMDEV (gpu )
65+ }
66+
67+ func (m * nvlibVGPUConfigManager ) SetVGPUConfig (gpu int , config types.VGPUConfig ) error {
68+ IsVFIOEnabled , err := m .IsVFIOEnabled ()
69+ if err != nil {
70+ return fmt .Errorf ("error checking if VFIO is enabled: %v" , err )
71+ }
72+ if IsVFIOEnabled {
73+ return m .SetVGPUConfigforVFIO (gpu , config )
74+ }
75+ return m .SetVGPUConfigforMDEV (gpu , config )
76+ }
77+
78+
5979func (m * nvlibVGPUConfigManager ) GetVGPUConfigforVFIO (gpu int ) (types.VGPUConfig , error ) {
6080 nvdevice , err := m .nvlib .Nvpci .GetGPUByIndex (gpu )
6181 if err != nil {
6282 return nil , fmt .Errorf ("unable to get GPU by index %d: %v" , gpu , err )
6383 }
6484 vgpuConfig := types.VGPUConfig {}
6585 vfnum := 0
66- if nvdevice .SriovInfo .PhysicalFunction == nil {
86+ if nvdevice .SriovInfo .IsVF () {
6787 return vgpuConfig , nil
6888 }
69- totalVF := int (nvdevice .SriovInfo .PhysicalFunction .TotalVFs )
70- for vfnum < totalVF {
71- vfAddr := HostPCIDevicesRoot + "/" + nvdevice .Address + "/ virtfn" + strconv .Itoa (vfnum ) + "/ nvidia"
89+ numVF := int (nvdevice .SriovInfo .PhysicalFunction .NumVFs )
90+ for vfnum < numVF {
91+ vfAddr := filepath . Join ( HostPCIDevicesRoot , nvdevice .Address , " virtfn", strconv .Itoa (vfnum ), " nvidia")
7292 if _ , err := os .Stat (vfAddr ); err != nil {
73- vfnum ++
74- continue
93+ return nil , fmt .Errorf ("Virtual Function %d at address %s does not exist" , vfnum , vfAddr )
7594 }
76- vgpuTypeNumberBytes , err := os .ReadFile (vfAddr + "/ current_vgpu_type" )
95+ vgpuTypeNumberBytes , err := os .ReadFile (filepath . Join ( vfAddr , " current_vgpu_type") )
7796 if err != nil {
7897 return nil , fmt .Errorf ("unable to read current vGPU type: %v" , err )
7998 }
8099 vgpuTypeNumber , err := strconv .Atoi (strings .TrimSpace (string (vgpuTypeNumberBytes )))
81100 if err != nil {
82- return nil , fmt .Errorf ("unable to convert current vGPU type to int: %v" , err )
101+ return nil , fmt .Errorf ("unable to convert current vGPU type number to int: %v" , err )
83102 }
84103 if vgpuTypeNumber == 0 {
85- vfnum ++
86- continue
104+ return nil , fmt .Errorf ("Virtual Function %d at address %s has no vGPU type assigned" , vfnum , vfAddr )
87105 }
88- vgpuTypeName , err := m .getVGPUTypeNameforVFIO ( vfAddr + "/ creatable_vgpu_types" , vgpuTypeNumber )
106+ vgpuTypeName , err := m .getVGPUTypeNameForId ( filepath . Join ( vfAddr , " creatable_vgpu_types") , vgpuTypeNumber )
89107 if err != nil {
90108 return nil , fmt .Errorf ("unable to get vGPU type name: %v" , err )
91109 }
@@ -118,22 +136,16 @@ func (m *nvlibVGPUConfigManager) SetVGPUConfigforVFIO(gpu int, config types.VGPU
118136 return fmt .Errorf ("GPU at index %d not found in available NVIDIA devices" , gpu )
119137 }
120138
121- cmd := exec .Command ("chroot" , "/host" , "/run/nvidia/driver/usr/lib/nvidia/sriov-manage" , "-e" , nvdevice .Address )
122- output , err := cmd .CombinedOutput ()
123- if err != nil {
124- return fmt .Errorf ("unable to execute sriov-manage: %v, output: %s" , err , string (output ))
125- }
126-
127139 vfnum := 0
128140 for key , val := range config {
129141 remainingToCreate := val
130142 for remainingToCreate > 0 {
131- vfAddr := HostPCIDevicesRoot + "/" + nvdevice .Address + "/ virtfn" + strconv .Itoa (vfnum ) + "/ nvidia"
132- number , err := m .getVGPUTypeNumberforVFIO ( vfAddr + "/ creatable_vgpu_types" , key )
143+ vfAddr := filepath . Join ( HostPCIDevicesRoot , nvdevice .Address , " virtfn", strconv .Itoa (vfnum ), " nvidia")
144+ number , err := m .getIdForVGPUTypeName ( filepath . Join ( vfAddr , " creatable_vgpu_types") , key )
133145 if err != nil {
134146 return fmt .Errorf ("unable to get vGPU type number: %v" , err )
135147 }
136- err = os .WriteFile (vfAddr + "/ current_vgpu_type" , []byte (strconv .Itoa (number )), 0644 )
148+ err = os .WriteFile (filepath . Join ( vfAddr , " current_vgpu_type") , []byte (strconv .Itoa (number )), 0644 )
137149 if err != nil {
138150 return fmt .Errorf ("unable to write current vGPU type: %v" , err )
139151 }
@@ -144,7 +156,7 @@ func (m *nvlibVGPUConfigManager) SetVGPUConfigforVFIO(gpu int, config types.VGPU
144156 return nil
145157}
146158
147- func (m * nvlibVGPUConfigManager ) getVGPUTypeNameforVFIO (filePath string , vgpuTypeNumber int ) (string , error ) {
159+ func (m * nvlibVGPUConfigManager ) getVGPUTypeNameForId (filePath string , vgpuTypeNumber int ) (string , error ) {
148160 file , err := os .Open (filePath )
149161 if err != nil {
150162 return "" , fmt .Errorf ("unable to open file %s: %v" , filePath , err )
@@ -167,7 +179,7 @@ func (m *nvlibVGPUConfigManager) getVGPUTypeNameforVFIO(filePath string, vgpuTyp
167179 return "" , fmt .Errorf ("vGPU type %d not found in file %s" , vgpuTypeNumber , filePath )
168180}
169181
170- func (m * nvlibVGPUConfigManager ) getVGPUTypeNumberforVFIO (filePath string , vgpuTypeName string ) (int , error ) {
182+ func (m * nvlibVGPUConfigManager ) getIdForVGPUTypeName (filePath string , vgpuTypeName string ) (int , error ) {
171183 file , err := os .Open (filePath )
172184 if err != nil {
173185 return 0 , fmt .Errorf ("unable to open file %s: %v" , filePath , err )
@@ -203,7 +215,7 @@ func IsVFIOEnabled() (bool, error) {
203215}
204216
205217// GetVGPUConfig gets the 'VGPUConfig' currently applied to a GPU at a particular index
206- func (m * nvlibVGPUConfigManager ) GetVGPUConfig (gpu int ) (types.VGPUConfig , error ) {
218+ func (m * nvlibVGPUConfigManager ) GetVGPUConfigforMDEV (gpu int ) (types.VGPUConfig , error ) {
207219 device , err := m .nvlib .Nvpci .GetGPUByIndex (gpu )
208220 if err != nil {
209221 return nil , fmt .Errorf ("error getting device at index '%d': %v" , gpu , err )
@@ -226,7 +238,7 @@ func (m *nvlibVGPUConfigManager) GetVGPUConfig(gpu int) (types.VGPUConfig, error
226238}
227239
228240// SetVGPUConfig applies the selected `VGPUConfig` to a GPU at a particular index if it is not already applied
229- func (m * nvlibVGPUConfigManager ) SetVGPUConfig (gpu int , config types.VGPUConfig ) error {
241+ func (m * nvlibVGPUConfigManager ) SetVGPUConfigforMDEV (gpu int , config types.VGPUConfig ) error {
230242 device , err := m .nvlib .Nvpci .GetGPUByIndex (gpu )
231243 if err != nil {
232244 return fmt .Errorf ("error getting device at index '%d': %v" , gpu , err )
0 commit comments