@@ -137,13 +137,13 @@ def get_gpu_local_binding(self, rule: dict, local_rank: int, gpus_per_task: int)
         Calculate a 'gpu-local' binding using the topology-aware ordered GPU list.
         """
         assignment = gpus.GPUAssignment.for_rank(local_rank, gpus_per_task, self.ordered_gpus)
-
+
         # The CPU domain is the union of NUMA nodes for the assigned GPUs.
         domain_locations = [f"numa:{i}" for i in assignment.numa_indices]
         domain = " ".join(domain_locations)
         cpu_binding_string = self.get_binding_in_gpu_domain(rule, local_rank, gpus_per_task, domain)
         return f"{cpu_binding_string};{assignment.cuda_devices}"
-
+
     def get_gpu_remote_binding(self, rule: dict, local_rank: int, gpus_per_task: int) -> str:
         """
         Calculates a 'gpu-remote' binding using the topology-aware ordered GPU list.
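For context on the first hunk: the gpu-local path returns a compound string of the form `<cpu binding>;<cuda devices>`. A minimal sketch of that formatting, using a stand-in for `gpus.GPUAssignment` whose two fields are assumed only from how they are used in the hunk above:

```python
# Illustrative stand-in for gpus.GPUAssignment; only the attributes used above are modeled.
from dataclasses import dataclass

@dataclass
class AssignmentSketch:
    numa_indices: set   # NUMA nodes hosting this rank's GPUs (assumed meaning)
    cuda_devices: str   # e.g. "0,1", consumed later as a CUDA device list (assumed meaning)

assignment = AssignmentSketch(numa_indices={0, 1}, cuda_devices="0,1")
domain = " ".join(f"numa:{i}" for i in assignment.numa_indices)  # "numa:0 numa:1"
cpu_binding_string = domain  # what get_binding_in_gpu_domain returns for a broad type
print(f"{cpu_binding_string};{assignment.cuda_devices}")  # -> "numa:0 numa:1;0,1"
```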
@@ -155,14 +155,16 @@ def get_gpu_remote_binding(self, rule: dict, local_rank: int, gpus_per_task: int
         # Find all remote NUMA domains relative to the set of local domains.
         all_numa_indices = set(range(len(self.numa_node_cpusets)))
         remote_numa_indices = sorted(list(all_numa_indices - assignment.numa_indices))
-
+
         if not remote_numa_indices:
-            raise RuntimeError(f"Cannot find a remote NUMA node for rank {local_rank}; its GPUs span all NUMA domains.")
-
-        offset = rule.get('offset', 0)
+            raise RuntimeError(
+                f"Cannot find a remote NUMA node for rank {local_rank}; its GPUs span all NUMA domains."
+            )
+
+        offset = rule.get("offset", 0)
         if offset >= len(remote_numa_indices):
             raise ValueError(f"Offset {offset} is out of range for remote NUMA domains.")
-
+
         target_remote_numa_idx = remote_numa_indices[offset]
         domain = f"numa:{target_remote_numa_idx}"
 
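A quick standalone check of the offset handling in this hunk, with made-up topology values (four NUMA nodes, GPUs local to node 0):

```python
# Sketch of the remote-domain selection above; all topology numbers are hypothetical.
local_numa = {0}          # assignment.numa_indices for this rank
total_numa_nodes = 4      # len(self.numa_node_cpusets) on an assumed node
offset = 1                # rule.get("offset", 0)

remote_numa_indices = sorted(set(range(total_numa_nodes)) - local_numa)  # [1, 2, 3]
if not remote_numa_indices:
    raise RuntimeError("GPUs span all NUMA domains")
if offset >= len(remote_numa_indices):
    raise ValueError(f"Offset {offset} is out of range for remote NUMA domains.")
print(f"numa:{remote_numa_indices[offset]}")  # -> "numa:2"
```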
@@ -171,7 +173,7 @@ def get_gpu_remote_binding(self, rule: dict, local_rank: int, gpus_per_task: int
 
     def get_binding_in_gpu_domain(
         self, rule: dict, local_rank: int, gpus_per_task: int, domain: str
-    ) -> str:
+    ):
         """
         A dedicated binding engine for GPU jobs. It applies user preferences within a calculated domain
         (e.g., "numa:0" or "numa:0 numa:1").
@@ -184,25 +186,58 @@ def get_binding_in_gpu_domain(
             # If a broad type is requested, the binding is the domain itself.
             return domain
 
-        if "prefer" in rule:
-            try:
-                requested_index = int(rule["prefer"])
-                # Validate by attempting to get the object.
-                return commands.hwloc_calc.get_object_in_set(domain, hwloc_type, requested_index)
-            except (ValueError, RuntimeError, TypeError):
-                print(
-                    f"Warning: Preferred index '{rule['prefer']}' invalid/not in domain '{domain}'. Falling back.",
-                    file=sys.stderr,
+        elif hwloc_type in ["core", "pu", "l2cache"]:
+
+            # Get the number of objects to select, defaulting to 1.
+            count = rule.get("count", 1)
+
+            all_indices_in_domain = commands.hwloc_calc.get_object_in_set(
+                domain, hwloc_type, "all"
+            ).split(",")
+            if not all_indices_in_domain or not all_indices_in_domain[0]:
+                raise RuntimeError(f"No objects of type '{hwloc_type}' found in domain '{domain}'.")
+
+            if "prefer" in rule:
+                if count > 1:
+                    raise ValueError("'prefer' and 'count > 1' cannot be used together.")
+                try:
+                    requested_index = str(int(rule["prefer"]))
+                    if requested_index in all_indices_in_domain:
+                        return f"{hwloc_type}:{requested_index}"
+                    else:
+                        print(
+                            f"Warning: Preferred index '{requested_index}' not available in domain '{domain}'. Falling back.",
+                            file=sys.stderr,
+                        )
+                except (ValueError, TypeError):
+                    raise ValueError(
+                        f"The 'prefer' key must be a simple integer, but got: {rule['prefer']}"
+                    )
+
+            # Default assignment: Calculate the slice of objects for this rank.
+            # We need to know this rank's turn on the current domain.
+            num_domains = len(domain.split())
+            rank_turn_in_domain = local_rank // num_domains
+
+            start_index = rank_turn_in_domain * count
+            end_index = start_index + count
+
+            if end_index > len(all_indices_in_domain):
+                raise ValueError(
+                    f"Not enough '{hwloc_type}' objects in domain '{domain}' to satisfy request "
+                    f"for {count} objects for rank {local_rank} (needs up to index {end_index - 1}, "
+                    f"only {len(all_indices_in_domain)} available)."
                 )
 
-        # Default assignment: Rank's Nth turn for a resource of this type within its GPU group.
-        # This is the correct index for packing sub-objects within a domain.
-        index = local_rank // gpus_per_task if gpus_per_task > 0 else local_rank
+            # Get the slice of object indices.
+            target_indices_slice = all_indices_in_domain[start_index:end_index]
 
-        # For certain patterns like interleave or spread, the index calculation
-        # would need to be more complex, but for a simple packed pattern this is the logic.
-        # Let's assume a simple packed logic for now as pattern is not yet implemented here.
-        return commands.hwloc_calc.get_object_in_set(domain, hwloc_type, index)
+            # Construct a space-separated list of location objects.
+            # e.g., "core:0 core:1 core:2 core:3 core:4 core:5"
+            binding_locations = [f"{hwloc_type}:{i}" for i in target_indices_slice]
+            return " ".join(binding_locations)
+        else:
+            raise ValueError(f"Unsupported type '{hwloc_type}' for GPU binding.")
 
     def get_binding_for_rank(self, rank, node_id, local_rank, gpus_per_task=None) -> str:
         """
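To make the new slicing logic in `get_binding_in_gpu_domain` concrete, here is a worked example with assumed inputs (the rule, rank, and hwloc index list are hypothetical; `hwloc-calc` would normally supply the indices):

```python
# Worked example of the rank-turn slicing introduced in get_binding_in_gpu_domain.
# All inputs are hypothetical stand-ins for what hwloc_calc and the rule would provide.
hwloc_type = "core"
rule = {"count": 2}                       # user asks for 2 cores per rank
local_rank = 3
domain = "numa:0 numa:1"                  # two NUMA domains -> ranks take turns across them
all_indices_in_domain = ["0", "1", "2", "3", "4", "5", "6", "7"]

count = rule.get("count", 1)
num_domains = len(domain.split())                  # 2
rank_turn_in_domain = local_rank // num_domains    # 3 // 2 = 1
start_index = rank_turn_in_domain * count          # 2
end_index = start_index + count                    # 4

binding = " ".join(f"{hwloc_type}:{i}" for i in all_indices_in_domain[start_index:end_index])
print(binding)                            # -> "core:2 core:3"
```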