Skip to content

Commit 8e19d45

Browse files
committed
GPU load balancing
1 parent e018522 commit 8e19d45

File tree

2 files changed

+175
-30
lines changed

2 files changed

+175
-30
lines changed

cmd/api/config/config.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,9 @@ type Config struct {
118118
// Hypervisor configuration
119119
DefaultHypervisor string // Default hypervisor type: "cloud-hypervisor" or "qemu"
120120

121+
// GPU configuration
122+
GPUProfileCacheTTL string // TTL for GPU profile metadata cache (e.g., "30m")
123+
121124
// Oversubscription ratios (1.0 = no oversubscription, 2.0 = 2x oversubscription)
122125
OversubCPU float64 // CPU oversubscription ratio
123126
OversubMemory float64 // Memory oversubscription ratio
@@ -212,6 +215,9 @@ func Load() *Config {
212215
// Hypervisor configuration
213216
DefaultHypervisor: getEnv("DEFAULT_HYPERVISOR", "cloud-hypervisor"),
214217

218+
// GPU configuration
219+
GPUProfileCacheTTL: getEnv("GPU_PROFILE_CACHE_TTL", "30m"),
220+
215221
// Oversubscription ratios (1.0 = no oversubscription)
216222
OversubCPU: getEnvFloat("OVERSUB_CPU", 4.0),
217223
OversubMemory: getEnvFloat("OVERSUB_MEMORY", 1.0),

lib/devices/mdev.go

Lines changed: 169 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,11 @@ import (
88
"os/exec"
99
"path/filepath"
1010
"regexp"
11+
"sort"
1112
"strconv"
1213
"strings"
1314
"sync"
15+
"time"
1416

1517
"github.com/google/uuid"
1618
"github.com/kernel/hypeman/lib/logger"
@@ -32,12 +34,52 @@ type profileMetadata struct {
3234
FramebufferMB int
3335
}
3436

35-
// cachedProfiles holds static profile metadata, loaded once on first access
37+
// cachedProfiles holds profile metadata with TTL-based expiry.
38+
// The cache TTL is configurable via the GPU_PROFILE_CACHE_TTL environment variable.
3639
var (
3740
cachedProfiles []profileMetadata
38-
cachedProfilesOnce sync.Once
41+
cachedProfilesMu sync.RWMutex
42+
cachedProfilesTime time.Time
3943
)
4044

45+
// getProfileCacheTTL returns how long cached GPU profile metadata stays valid.
//
// The value is read from the GPU_PROFILE_CACHE_TTL environment variable and
// parsed with time.ParseDuration (e.g. "30m", "1h30m"). Leading and trailing
// whitespace is ignored. The 30-minute default is returned when the variable
// is unset, empty, unparsable, or negative. A zero duration is returned as-is.
//
// NOTE(review): this reads the environment directly on every call rather than
// taking the value from the API Config struct — confirm that is intended.
func getProfileCacheTTL() time.Duration {
	const defaultTTL = 30 * time.Minute

	raw := strings.TrimSpace(os.Getenv("GPU_PROFILE_CACHE_TTL"))
	if raw == "" {
		return defaultTTL
	}

	ttl, err := time.ParseDuration(raw)
	if err != nil || ttl < 0 {
		// A malformed or negative TTL is a configuration mistake; use the
		// default instead of propagating a value that can never be satisfied.
		return defaultTTL
	}
	return ttl
}
55+
56+
// getCachedProfiles returns cached profile metadata, refreshing if TTL has expired.
57+
func getCachedProfiles(firstVF string) []profileMetadata {
58+
ttl := getProfileCacheTTL()
59+
60+
// Fast path: check with read lock
61+
cachedProfilesMu.RLock()
62+
if len(cachedProfiles) > 0 && time.Since(cachedProfilesTime) < ttl {
63+
profiles := cachedProfiles
64+
cachedProfilesMu.RUnlock()
65+
return profiles
66+
}
67+
cachedProfilesMu.RUnlock()
68+
69+
// Slow path: refresh cache with write lock
70+
cachedProfilesMu.Lock()
71+
defer cachedProfilesMu.Unlock()
72+
73+
// Double-check after acquiring write lock
74+
if len(cachedProfiles) > 0 && time.Since(cachedProfilesTime) < ttl {
75+
return cachedProfiles
76+
}
77+
78+
cachedProfiles = loadProfileMetadata(firstVF)
79+
cachedProfilesTime = time.Now()
80+
return cachedProfiles
81+
}
82+
4183
// DiscoverVFs returns all SR-IOV Virtual Functions available for vGPU.
4284
// These are discovered by scanning /sys/class/mdev_bus/ which contains
4385
// VFs that can host mdev devices.
@@ -100,17 +142,15 @@ func ListGPUProfilesWithVFs(vfs []VirtualFunction) ([]GPUProfile, error) {
100142
return nil, nil
101143
}
102144

103-
// Load static profile metadata once (cached indefinitely)
104-
cachedProfilesOnce.Do(func() {
105-
cachedProfiles = loadProfileMetadata(vfs[0].PCIAddress)
106-
})
145+
// Load profile metadata with TTL-based caching
146+
cachedMeta := getCachedProfiles(vfs[0].PCIAddress)
107147

108148
// Count availability for all profiles in parallel
109-
availability := countAvailableVFsForProfilesParallel(vfs, cachedProfiles)
149+
availability := countAvailableVFsForProfilesParallel(vfs, cachedMeta)
110150

111151
// Build result with dynamic availability counts
112-
profiles := make([]GPUProfile, 0, len(cachedProfiles))
113-
for _, meta := range cachedProfiles {
152+
profiles := make([]GPUProfile, 0, len(cachedMeta))
153+
for _, meta := range cachedMeta {
114154
profiles = append(profiles, GPUProfile{
115155
Name: meta.Name,
116156
FramebufferMB: meta.FramebufferMB,
@@ -194,8 +234,8 @@ func parseFramebufferFromDescription(typeDir string) int {
194234
}
195235

196236
// countAvailableVFsForProfilesParallel counts available instances for all profiles in parallel.
197-
// Optimized: all VFs on the same parent GPU have identical profile support,
198-
// so we only sample one VF per parent instead of reading from every VF.
237+
// Groups VFs by parent GPU, then sums available_instances across all free VFs.
238+
// For SR-IOV vGPU, each VF typically has available_instances of 0 or 1.
199239
func countAvailableVFsForProfilesParallel(vfs []VirtualFunction, profiles []profileMetadata) map[string]int {
200240
if len(vfs) == 0 || len(profiles) == 0 {
201241
return make(map[string]int)
@@ -352,6 +392,118 @@ func getProfileNameFromType(profileType, vfAddress string) string {
352392
return strings.TrimSpace(string(data))
353393
}
354394

395+
// getProfileFramebufferMB returns the framebuffer size in MB for a profile type.
396+
// Uses cached profile metadata for fast lookup.
397+
func getProfileFramebufferMB(profileType string) int {
398+
cachedProfilesMu.RLock()
399+
defer cachedProfilesMu.RUnlock()
400+
401+
for _, p := range cachedProfiles {
402+
if p.TypeName == profileType {
403+
return p.FramebufferMB
404+
}
405+
}
406+
return 0
407+
}
408+
409+
// calculateGPUVRAMUsage calculates VRAM usage per GPU from active mdevs.
410+
// Returns a map of parentGPU -> usedVRAMMB.
411+
func calculateGPUVRAMUsage(vfs []VirtualFunction, mdevs []MdevDevice) map[string]int {
412+
// Build VF -> parentGPU lookup
413+
vfToParent := make(map[string]string, len(vfs))
414+
for _, vf := range vfs {
415+
vfToParent[vf.PCIAddress] = vf.ParentGPU
416+
}
417+
418+
// Sum framebuffer usage per GPU
419+
usageByGPU := make(map[string]int)
420+
for _, mdev := range mdevs {
421+
parentGPU := vfToParent[mdev.VFAddress]
422+
if parentGPU == "" {
423+
continue
424+
}
425+
usageByGPU[parentGPU] += getProfileFramebufferMB(mdev.ProfileType)
426+
}
427+
428+
return usageByGPU
429+
}
430+
431+
// selectLeastLoadedVF selects a VF from the GPU with the most available VRAM
432+
// that can create the requested profile. Returns empty string if none available.
433+
func selectLeastLoadedVF(ctx context.Context, vfs []VirtualFunction, profileType string) string {
434+
log := logger.FromContext(ctx)
435+
436+
// Get active mdevs to calculate VRAM usage
437+
mdevs, _ := ListMdevDevices()
438+
439+
// Calculate VRAM usage per GPU
440+
vramUsage := calculateGPUVRAMUsage(vfs, mdevs)
441+
442+
// Group free VFs by parent GPU
443+
freeVFsByGPU := make(map[string][]VirtualFunction)
444+
allGPUs := make(map[string]bool)
445+
for _, vf := range vfs {
446+
allGPUs[vf.ParentGPU] = true
447+
if !vf.HasMdev {
448+
freeVFsByGPU[vf.ParentGPU] = append(freeVFsByGPU[vf.ParentGPU], vf)
449+
}
450+
}
451+
452+
// Build list of GPUs sorted by VRAM usage (ascending = least loaded first)
453+
type gpuLoad struct {
454+
gpu string
455+
usedMB int
456+
}
457+
var gpuLoads []gpuLoad
458+
for gpu := range allGPUs {
459+
gpuLoads = append(gpuLoads, gpuLoad{gpu: gpu, usedMB: vramUsage[gpu]})
460+
}
461+
sort.Slice(gpuLoads, func(i, j int) bool {
462+
return gpuLoads[i].usedMB < gpuLoads[j].usedMB
463+
})
464+
465+
log.DebugContext(ctx, "GPU VRAM usage for load balancing",
466+
"gpu_count", len(gpuLoads),
467+
"profile_type", profileType)
468+
469+
// Try each GPU in order of least loaded
470+
for _, gl := range gpuLoads {
471+
freeVFs := freeVFsByGPU[gl.gpu]
472+
if len(freeVFs) == 0 {
473+
log.DebugContext(ctx, "skipping GPU: no free VFs",
474+
"gpu", gl.gpu,
475+
"used_mb", gl.usedMB)
476+
continue
477+
}
478+
479+
// Check if any free VF on this GPU can create the profile
480+
for _, vf := range freeVFs {
481+
availPath := filepath.Join(mdevBusPath, vf.PCIAddress, "mdev_supported_types", profileType, "available_instances")
482+
data, err := os.ReadFile(availPath)
483+
if err != nil {
484+
continue
485+
}
486+
instances, err := strconv.Atoi(strings.TrimSpace(string(data)))
487+
if err != nil || instances < 1 {
488+
continue
489+
}
490+
491+
log.DebugContext(ctx, "selected VF from least loaded GPU",
492+
"vf", vf.PCIAddress,
493+
"gpu", gl.gpu,
494+
"gpu_used_mb", gl.usedMB)
495+
return vf.PCIAddress
496+
}
497+
498+
log.DebugContext(ctx, "skipping GPU: no VF can create profile",
499+
"gpu", gl.gpu,
500+
"used_mb", gl.usedMB,
501+
"profile_type", profileType)
502+
}
503+
504+
return ""
505+
}
506+
355507
// CreateMdev creates an mdev device for the given profile and instance.
356508
// It finds an available VF and creates the mdev, returning the device info.
357509
// This function is thread-safe and uses a mutex to prevent race conditions
@@ -369,32 +521,19 @@ func CreateMdev(ctx context.Context, profileName, instanceID string) (*MdevDevic
369521
return nil, err
370522
}
371523

372-
// Find an available VF
524+
// Discover all VFs
373525
vfs, err := DiscoverVFs()
374526
if err != nil {
375527
return nil, fmt.Errorf("discover VFs: %w", err)
376528
}
377529

378-
var targetVF string
379-
for _, vf := range vfs {
380-
// Skip VFs that already have an mdev
381-
if vf.HasMdev {
382-
continue
383-
}
384-
// Check if this VF can create the profile
385-
availPath := filepath.Join(mdevBusPath, vf.PCIAddress, "mdev_supported_types", profileType, "available_instances")
386-
data, err := os.ReadFile(availPath)
387-
if err != nil {
388-
continue
389-
}
390-
instances, err := strconv.Atoi(strings.TrimSpace(string(data)))
391-
if err != nil || instances < 1 {
392-
continue
393-
}
394-
targetVF = vf.PCIAddress
395-
break
530+
// Ensure profile cache is populated (needed for VRAM calculation)
531+
if len(vfs) > 0 {
532+
_ = getCachedProfiles(vfs[0].PCIAddress)
396533
}
397534

535+
// Select VF from the least loaded GPU (by VRAM usage)
536+
targetVF := selectLeastLoadedVF(ctx, vfs, profileType)
398537
if targetVF == "" {
399538
return nil, fmt.Errorf("no available VF for profile %q", profileName)
400539
}

0 commit comments

Comments
 (0)