@@ -53,28 +53,35 @@ type GpuAllocator struct {
 	dirtyQueueLock sync.Mutex
 }
 
+// AllocRequest encapsulates all parameters needed for GPU allocation
+type AllocRequest struct {
+	// Name of the GPU pool to allocate from
+	PoolName string
+	// Namespace information for the workload
+	WorkloadNameNamespace tfv1.NameNamespace
+	// Resource requirements for the allocation
+	Request tfv1.Resource
+	// Number of GPUs to allocate
+	Count uint
+	// Specific GPU model to allocate, empty string means any model
+	GPUModel string
+}
+
 // Alloc allocates a request to a gpu or multiple gpus from the same node.
-func (s *GpuAllocator) Alloc(
-	ctx context.Context,
-	poolName string,
-	workloadNameNamespace tfv1.NameNamespace,
-	request tfv1.Resource,
-	count uint,
-	gpuModel string,
-) ([]*tfv1.GPU, error) {
+func (s *GpuAllocator) Alloc(ctx context.Context, req AllocRequest) ([]*tfv1.GPU, error) {
 	// Get GPUs from the pool using the in-memory store
-	poolGPUs := s.listGPUsFromPool(poolName)
+	poolGPUs := s.listGPUsFromPool(req.PoolName)
 
 	// Add SameNodeFilter if count > 1 to ensure GPUs are from the same node
-	filterRegistry := s.filterRegistry.With(filter.NewResourceFilter(request))
+	filterRegistry := s.filterRegistry.With(filter.NewResourceFilter(req.Request))
 
 	// Add GPU model filter if specified
-	if gpuModel != "" {
-		filterRegistry = filterRegistry.With(filter.NewGPUModelFilter(gpuModel))
+	if req.GPUModel != "" {
+		filterRegistry = filterRegistry.With(filter.NewGPUModelFilter(req.GPUModel))
 	}
 
-	if count > 1 {
-		filterRegistry = filterRegistry.With(filter.NewSameNodeFilter(count))
+	if req.Count > 1 {
+		filterRegistry = filterRegistry.With(filter.NewSameNodeFilter(req.Count))
 	}
 
 	// Apply the filters in sequence
@@ -84,12 +91,12 @@ func (s *GpuAllocator) Alloc(
 	}
 
 	if len(filteredGPUs) == 0 {
-		return nil, fmt.Errorf("no gpus available in pool %s after filtering", poolName)
+		return nil, fmt.Errorf("no gpus available in pool %s after filtering", req.PoolName)
 	}
 
 	pool := &tfv1.GPUPool{}
-	if err := s.Get(ctx, client.ObjectKey{Name: poolName}, pool); err != nil {
-		return nil, fmt.Errorf("get pool %s: %w", poolName, err)
+	if err := s.Get(ctx, client.ObjectKey{Name: req.PoolName}, pool); err != nil {
+		return nil, fmt.Errorf("get pool %s: %w", req.PoolName, err)
 	}
 
 	schedulingConfigTemplate := &tfv1.SchedulingConfigTemplate{}
@@ -100,7 +107,7 @@ func (s *GpuAllocator) Alloc(
 	}
 
 	strategy := NewStrategy(schedulingConfigTemplate.Spec.Placement.Mode)
-	selectedGPUs, err := strategy.SelectGPUs(filteredGPUs, count)
+	selectedGPUs, err := strategy.SelectGPUs(filteredGPUs, req.Count)
 	if err != nil {
 		return nil, fmt.Errorf("select GPU: %w", err)
 	}
@@ -121,11 +128,11 @@ func (s *GpuAllocator) Alloc(
 	}
 
 	// reduce available resource on the GPU status
-	gpu.Status.Available.Tflops.Sub(request.Tflops)
-	gpu.Status.Available.Vram.Sub(request.Vram)
+	gpu.Status.Available.Tflops.Sub(req.Request.Tflops)
+	gpu.Status.Available.Vram.Sub(req.Request.Vram)
 
 	if !appAdded {
-		addRunningApp(ctx, gpu, workloadNameNamespace)
+		addRunningApp(ctx, gpu, req.WorkloadNameNamespace)
 		appAdded = true
 	}
 
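
For reference, here is what a call site looks like after this change: a minimal sketch, assuming `tfv1.NameNamespace` has `Name`/`Namespace` fields and that `tfv1.Resource.Tflops`/`Vram` are `resource.Quantity` values (consistent with the `.Sub` calls above). The pool name, namespace, and quantities are illustrative, not taken from the repository; imports of the project's own packages (`tfv1`, the allocator) are elided because their paths don't appear in this diff.

```go
import (
	"context"
	"fmt"

	"k8s.io/apimachinery/pkg/api/resource"
	// tfv1 and the allocator package are project-local; import paths elided.
)

// allocateTwo is a hypothetical caller of the new struct-based API.
func allocateTwo(ctx context.Context, s *GpuAllocator) error {
	req := AllocRequest{
		PoolName: "example-pool", // illustrative name
		WorkloadNameNamespace: tfv1.NameNamespace{
			Namespace: "default",          // illustrative
			Name:      "example-workload", // illustrative
		},
		Request: tfv1.Resource{
			Tflops: resource.MustParse("10"),   // assumed Quantity fields
			Vram:   resource.MustParse("16Gi"), // assumed Quantity fields
		},
		Count:    2,  // >1 adds the SameNodeFilter, keeping GPUs on one node
		GPUModel: "", // empty string means any model
	}

	gpus, err := s.Alloc(ctx, req)
	if err != nil {
		return fmt.Errorf("alloc from pool %s: %w", req.PoolName, err)
	}
	fmt.Printf("allocated %d GPU(s)\n", len(gpus))
	return nil
}
```

Grouping the five former positional parameters into `AllocRequest` also means future allocation knobs can be added without touching every call site.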