Commit 66e0f21

gpu: make shared assignment class
Signed-off-by: vsoch <[email protected]>
1 parent 7342925 commit 66e0f21

2 files changed: 44 additions, 69 deletions

fluxbind/shape/gpu.py

Lines changed: 29 additions & 24 deletions

```diff
@@ -1,42 +1,47 @@
 from dataclasses import dataclass
 
-
 @dataclass
 class GPUAssignment:
     """
-    Data structure to hold information about a rank's assigned GPU.
+    A data structure to hold information about a rank's assigned GPU(s).
+    Instances are created via the for_rank() classmethod.
     """
-
-    indices: list[int]  # logical index of the GPU
-    pci_ids: list[str]  # The PCI bus ID of the GPU
-    cuda_devices: str  # CUDA_VISIBLE_DEVICES
+    indices: list[int]  # The logical indices in the ordered list (e.g., [4, 5])
+    pci_ids: list[str]  # The corresponding PCI bus IDs of the GPUs
+    numa_indices: set[int]  # The set of unique NUMA nodes these GPUs are on
+    cuda_devices: str  # The final string for CUDA_VISIBLE_DEVICES (e.g., "4,5")
 
     @classmethod
-    def for_rank(cls, local_rank, gpus_per_task=None, gpu_pci_ids=None):
+    def for_rank(
+        cls,
+        local_rank: int,
+        gpus_per_task: int,
+        ordered_gpus: list[dict]
+    ) -> "GPUAssignment":
         """
-        A factory method that assigns a GPU to a given local rank
-        using a round-robin strategy.
+        A factory method that assigns a slice of GPUs to a given local rank
+        from a pre-ordered, topology-aware list of all GPUs.
         """
-        if not gpu_pci_ids:
+        if not ordered_gpus:
             raise RuntimeError("Attempted to assign a GPU, but no GPUs were discovered.")
+
+        start_idx = local_rank * gpus_per_task
+        end_idx = start_idx + gpus_per_task
 
-        # Assume one gpu per task, since we are calling this, period
-        gpus_per_task = gpus_per_task or 1
-
-        # 1. Calculate the starting GPU index for this rank.
-        start_gpu_index = local_rank * gpus_per_task
-        end_gpu_index = start_gpu_index + gpus_per_task
-
-        if end_gpu_index > len(gpu_pci_ids):
+        if end_idx > len(ordered_gpus):
             raise ValueError(
                 f"Cannot satisfy request for {gpus_per_task} GPUs for local_rank {local_rank}. "
-                f"Only {len(gpu_pci_ids)} GPUs available in total."
+                f"Only {len(ordered_gpus)} GPUs available in total."
             )
 
-        # Return the assignment
-        assigned_indices = list(range(start_gpu_index, end_gpu_index))
+        assigned_gpu_slice = ordered_gpus[start_idx:end_idx]
+
+        # The global indices for CUDA_VISIBLE_DEVICES are their positions in the ordered list
+        assigned_indices = list(range(start_idx, end_idx))
+
         return cls(
             indices=assigned_indices,
-            pci_ids=[gpu_pci_ids[i] for i in assigned_indices],
-            cuda_devices=",".join([str(x) for x in assigned_indices]),
-        )
+            pci_ids=[gpu['pci_id'] for gpu in assigned_gpu_slice],
+            numa_indices={gpu['numa_index'] for gpu in assigned_gpu_slice},
+            cuda_devices=",".join(map(str, assigned_indices))
+        )
```
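
To make the new interface concrete, here is a minimal usage sketch. The `ordered_gpus` entries are invented sample data; the only assumption taken from the diff is that each entry is a dict with `pci_id` and `numa_index` keys, and the import path simply mirrors the file location `fluxbind/shape/gpu.py`.

```python
# Hypothetical sample data: four GPUs, two per NUMA node, already in
# topology-aware order. PCI addresses and the NUMA layout are invented.
from fluxbind.shape.gpu import GPUAssignment  # assumed import path

ordered_gpus = [
    {"pci_id": "0000:07:00.0", "numa_index": 0},
    {"pci_id": "0000:0b:00.0", "numa_index": 0},
    {"pci_id": "0000:48:00.0", "numa_index": 1},
    {"pci_id": "0000:4c:00.0", "numa_index": 1},
]

# Local rank 1 with 2 GPUs per task receives the second slice (positions 2 and 3).
assignment = GPUAssignment.for_rank(local_rank=1, gpus_per_task=2, ordered_gpus=ordered_gpus)

print(assignment.indices)       # [2, 3]
print(assignment.pci_ids)       # ['0000:48:00.0', '0000:4c:00.0']
print(assignment.numa_indices)  # {1}
print(assignment.cuda_devices)  # 2,3
```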

fluxbind/shape/shape.py

Lines changed: 15 additions & 45 deletions

```diff
@@ -135,68 +135,38 @@ def get_gpu_local_binding(self, rule: dict, local_rank: int, gpus_per_task: int)
         """
         Calculate a 'gpu-local' binding using the topology-aware ordered GPU list.
         """
-        if not self.ordered_gpus:
-            raise RuntimeError("Shape specifies 'bind: gpu-local', but no GPUs were discovered.")
-
-        # Assign a slice of GPUs from the canonical, ordered list.
-        start_idx = local_rank * gpus_per_task
-        end_idx = start_idx + gpus_per_task
-        if end_idx > len(self.ordered_gpus):
-            raise ValueError(f"Not enough total GPUs to satisfy request for rank {local_rank}.")
-
-        assigned_gpu_slice = self.ordered_gpus[start_idx:end_idx]
-        cuda_devices = ",".join([str(start_idx + i) for i, _ in enumerate(assigned_gpu_slice)])
-
+        assignment = gpus.GPUAssignment.for_rank(local_rank, gpus_per_task, self.ordered_gpus)
+
         # The CPU domain is the union of NUMA nodes for the assigned GPUs.
-        local_numa_indices = sorted(list({gpu["numa_index"] for gpu in assigned_gpu_slice}))
-        domain_locations = [f"numa:{i}" for i in local_numa_indices]
-        domain = " ".join(domain_locations)  # e.g., "numa:0" or "numa:0 numa:1"
-
-        # Get the final CPU binding WITHIN that domain.
+        domain_locations = [f"numa:{i}" for i in assignment.numa_indices]
+        domain = " ".join(domain_locations)
         cpu_binding_string = self.get_binding_in_gpu_domain(rule, local_rank, gpus_per_task, domain)
-        return f"{cpu_binding_string};{cuda_devices}"
-
+        return f"{cpu_binding_string};{assignment.cuda_devices}"
+
     def get_gpu_remote_binding(self, rule: dict, local_rank: int, gpus_per_task: int) -> str:
         """
         Calculates a 'gpu-remote' binding using the topology-aware ordered GPU list.
         """
         if len(self.numa_node_cpusets) < 2:
             raise RuntimeError("'bind: gpu-remote' is invalid on a single-NUMA system.")
-        if not self.ordered_gpus:
-            raise RuntimeError("Shape specifies 'bind: gpu-remote', but no GPUs were discovered.")
-
-        # Assign a slice of GPUs to determine the local NUMA domains.
-        start_idx = local_rank * gpus_per_task
-        end_idx = start_idx + gpus_per_task
-        if end_idx > len(self.ordered_gpus):
-            raise ValueError(f"Not enough total GPUs to satisfy request for rank {local_rank}.")
+        assignment = gpus.GPUAssignment.for_rank(local_rank, gpus_per_task, self.ordered_gpus)
 
-        assigned_gpu_slice = self.ordered_gpus[start_idx:end_idx]
-        cuda_devices = ",".join([str(start_idx + i) for i, _ in enumerate(assigned_gpu_slice)])
-
-        # Find the set of all local NUMA domains for this rank's GPUs.
-        local_numa_indices = {gpu["numa_index"] for gpu in assigned_gpu_slice}
-
-        # Find all remote NUMA domains.
+        # Find all remote NUMA domains relative to the set of local domains.
         all_numa_indices = set(range(len(self.numa_node_cpusets)))
-        remote_numa_indices = sorted(list(all_numa_indices - local_numa_indices))
-
+        remote_numa_indices = sorted(list(all_numa_indices - assignment.numa_indices))
+
         if not remote_numa_indices:
-            raise RuntimeError(
-                f"Cannot find a remote NUMA node for rank {local_rank}; its GPUs span all NUMA domains."
-            )
-
-        # 4. Select the target remote domain.
-        offset = rule.get("offset", 0)
+            raise RuntimeError(f"Cannot find a remote NUMA node for rank {local_rank}; its GPUs span all NUMA domains.")
+
+        offset = rule.get('offset', 0)
         if offset >= len(remote_numa_indices):
             raise ValueError(f"Offset {offset} is out of range for remote NUMA domains.")
-
+
         target_remote_numa_idx = remote_numa_indices[offset]
         domain = f"numa:{target_remote_numa_idx}"
 
-        # Get the final CPU binding WITHIN that remote domain.
         cpu_binding_string = self.get_binding_in_gpu_domain(rule, local_rank, gpus_per_task, domain)
-        return f"{cpu_binding_string};{cuda_devices}"
+        return f"{cpu_binding_string};{assignment.cuda_devices}"
 
     def get_binding_in_gpu_domain(
         self, rule: dict, local_rank: int, gpus_per_task: int, domain: str
```
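
The NUMA-domain logic that both binding methods now share is easy to replay in isolation. The sketch below is not fluxbind code: every value is invented (four NUMA nodes, a rank whose GPUs sit on node 1, an `offset` of 1), and the names only stand in for `self.numa_node_cpusets`, `assignment.numa_indices`, and `rule.get('offset', 0)`.

```python
# Standalone sketch of the domain selection in get_gpu_local_binding and
# get_gpu_remote_binding, using invented values.
num_numa_nodes = 4      # stand-in for len(self.numa_node_cpusets)
numa_indices = {1}      # stand-in for assignment.numa_indices
offset = 1              # stand-in for rule.get('offset', 0)

# gpu-local: the CPU domain is the union of the assigned GPUs' NUMA nodes.
local_domain = " ".join(f"numa:{i}" for i in numa_indices)      # "numa:1"

# gpu-remote: choose a NUMA node the assigned GPUs are NOT attached to.
all_numa_indices = set(range(num_numa_nodes))
remote_numa_indices = sorted(all_numa_indices - numa_indices)   # [0, 2, 3]
if not remote_numa_indices:
    raise RuntimeError("GPUs span all NUMA domains; no remote node exists.")
if offset >= len(remote_numa_indices):
    raise ValueError("Offset is out of range for remote NUMA domains.")
remote_domain = f"numa:{remote_numa_indices[offset]}"           # "numa:2"

print(local_domain, remote_domain)
```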
