@@ -135,68 +135,38 @@ def get_gpu_local_binding(self, rule: dict, local_rank: int, gpus_per_task: int)
135135 """
136136 Calculate a 'gpu-local' binding using the topology-aware ordered GPU list.
137137 """
138- if not self .ordered_gpus :
139- raise RuntimeError ("Shape specifies 'bind: gpu-local', but no GPUs were discovered." )
140-
141- # Assign a slice of GPUs from the canonical, ordered list.
142- start_idx = local_rank * gpus_per_task
143- end_idx = start_idx + gpus_per_task
144- if end_idx > len (self .ordered_gpus ):
145- raise ValueError (f"Not enough total GPUs to satisfy request for rank { local_rank } ." )
146-
147- assigned_gpu_slice = self .ordered_gpus [start_idx :end_idx ]
148- cuda_devices = "," .join ([str (start_idx + i ) for i , _ in enumerate (assigned_gpu_slice )])
149-
138+ assignment = gpus .GPUAssignment .for_rank (local_rank , gpus_per_task , self .ordered_gpus )
139+
150140 # The CPU domain is the union of NUMA nodes for the assigned GPUs.
151- local_numa_indices = sorted (list ({gpu ["numa_index" ] for gpu in assigned_gpu_slice }))
152- domain_locations = [f"numa:{ i } " for i in local_numa_indices ]
153- domain = " " .join (domain_locations ) # e.g., "numa:0" or "numa:0 numa:1"
154-
155- # Get the final CPU binding WITHIN that domain.
141+ domain_locations = [f"numa:{ i } " for i in assignment .numa_indices ]
142+ domain = " " .join (domain_locations )
156143 cpu_binding_string = self .get_binding_in_gpu_domain (rule , local_rank , gpus_per_task , domain )
157- return f"{ cpu_binding_string } ;{ cuda_devices } "
158-
144+ return f"{ cpu_binding_string } ;{ assignment . cuda_devices } "
145+
     def get_gpu_remote_binding(self, rule: dict, local_rank: int, gpus_per_task: int) -> str:
         """
         Calculates a 'gpu-remote' binding using the topology-aware ordered GPU list.
         """
         if len(self.numa_node_cpusets) < 2:
             raise RuntimeError("'bind: gpu-remote' is invalid on a single-NUMA system.")
-        if not self.ordered_gpus:
-            raise RuntimeError("Shape specifies 'bind: gpu-remote', but no GPUs were discovered.")
-
-        # Assign a slice of GPUs to determine the local NUMA domains.
-        start_idx = local_rank * gpus_per_task
-        end_idx = start_idx + gpus_per_task
-        if end_idx > len(self.ordered_gpus):
-            raise ValueError(f"Not enough total GPUs to satisfy request for rank {local_rank}.")
+        assignment = gpus.GPUAssignment.for_rank(local_rank, gpus_per_task, self.ordered_gpus)
 
-        assigned_gpu_slice = self.ordered_gpus[start_idx:end_idx]
-        cuda_devices = ",".join([str(start_idx + i) for i, _ in enumerate(assigned_gpu_slice)])
-
-        # Find the set of all local NUMA domains for this rank's GPUs.
-        local_numa_indices = {gpu["numa_index"] for gpu in assigned_gpu_slice}
-
-        # Find all remote NUMA domains.
+        # Find all remote NUMA domains relative to the set of local domains.
         all_numa_indices = set(range(len(self.numa_node_cpusets)))
-        remote_numa_indices = sorted(list(all_numa_indices - local_numa_indices))
-
+        remote_numa_indices = sorted(list(all_numa_indices - assignment.numa_indices))
+
         if not remote_numa_indices:
-            raise RuntimeError(
-                f"Cannot find a remote NUMA node for rank {local_rank}; its GPUs span all NUMA domains."
-            )
-
-        # 4. Select the target remote domain.
-        offset = rule.get("offset", 0)
+            raise RuntimeError(f"Cannot find a remote NUMA node for rank {local_rank}; its GPUs span all NUMA domains.")
+
+        offset = rule.get('offset', 0)
         if offset >= len(remote_numa_indices):
             raise ValueError(f"Offset {offset} is out of range for remote NUMA domains.")
-
+
         target_remote_numa_idx = remote_numa_indices[offset]
         domain = f"numa:{target_remote_numa_idx}"
 
-        # Get the final CPU binding WITHIN that remote domain.
         cpu_binding_string = self.get_binding_in_gpu_domain(rule, local_rank, gpus_per_task, domain)
-        return f"{cpu_binding_string};{cuda_devices}"
+        return f"{cpu_binding_string};{assignment.cuda_devices}"
 
     def get_binding_in_gpu_domain(
         self, rule: dict, local_rank: int, gpus_per_task: int, domain: str
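
For reference, both methods now delegate per-rank GPU slicing to a new `gpus.GPUAssignment.for_rank` helper that is not shown in this hunk. Below is a minimal sketch of what that helper might look like, reconstructed from the logic the diff removes; the `dataclass` shape and the `gpu_slice` field name are guesses, and the `frozenset` type for `numa_indices` is an assumption implied by the `all_numa_indices - assignment.numa_indices` set difference above.

```python
# Sketch only -- reconstructed from the removed code, not the actual gpus module.
from dataclasses import dataclass


@dataclass(frozen=True)
class GPUAssignment:
    """The contiguous slice of the ordered GPU list assigned to one local rank."""

    gpu_slice: list          # hypothetical field: the assigned ordered_gpus entries
    cuda_devices: str        # e.g. "2,3" -- positions in the canonical ordered list
    numa_indices: frozenset  # NUMA nodes hosting the assigned GPUs

    @classmethod
    def for_rank(cls, local_rank, gpus_per_task, ordered_gpus):
        if not ordered_gpus:
            raise RuntimeError("GPU binding requested, but no GPUs were discovered.")
        start_idx = local_rank * gpus_per_task
        end_idx = start_idx + gpus_per_task
        if end_idx > len(ordered_gpus):
            raise ValueError(f"Not enough total GPUs to satisfy request for rank {local_rank}.")
        gpu_slice = ordered_gpus[start_idx:end_idx]
        # Device IDs are indices into the canonical ordered list, matching the
        # removed `start_idx + i` logic.
        cuda_devices = ",".join(str(start_idx + i) for i in range(len(gpu_slice)))
        numa_indices = frozenset(gpu["numa_index"] for gpu in gpu_slice)
        return cls(gpu_slice, cuda_devices, numa_indices)
```

One caveat in this sketch: `get_gpu_local_binding` iterates `assignment.numa_indices` to build its `numa:{i}` domain string, and frozenset iteration order is unspecified, so the real helper presumably keeps the indices sorted (as the removed `sorted(list({...}))` did) while still supporting set difference.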