Commit 28d8e05

gpu: use expression syntax
Signed-off-by: vsoch <[email protected]>
1 parent 993a3ad commit 28d8e05

File tree

1 file changed: +45 -26 lines changed


fluxbind/shape/shape.py

Lines changed: 45 additions & 26 deletions
@@ -1,3 +1,4 @@
+import re
 import subprocess
 import sys

@@ -75,10 +76,23 @@ def evaluate_formula(formula_template: str, local_rank: int) -> int:

         This assumes running on the rank where the binding is asked for.
         """
-        formula = str(formula_template).replace("$local_rank", str(local_rank))
-        command = f'echo "{formula}"'
+        processed_formula = str(formula_template)
+        substitutions = re.findall(r"\{\{([^}]+)\}\}", processed_formula)
+        for command_to_run in set(substitutions):
+            try:
+                result = subprocess.run(
+                    command_to_run, shell=True, capture_output=True, text=True, check=True
+                )
+                placeholder = f"{{{{{command_to_run}}}}}"
+                processed_formula = processed_formula.replace(placeholder, result.stdout.strip())
+            except subprocess.CalledProcessError as e:
+                raise RuntimeError(f"Error executing sub-command '{command_to_run}': {e}")
+
+        # Substitute local_rank and evaluate the final expression
+        final_expression = processed_formula.replace("$local_rank", str(local_rank))
+        command = f'echo "{final_expression}"'
         result = subprocess.run(command, shell=True, capture_output=True, text=True, check=True)
-        return int(result.stdout.strip())
+        return result.stdout.strip()

     def find_matching_rule(self, rank: int, node_id: int) -> dict:
         """
@@ -105,45 +119,49 @@ def find_matching_rule(self, rank: int, node_id: int) -> dict:
             return item["default"]
         return None

-    def get_gpu_binding_for_rank(self, hwloc_type, local_rank):
+    def get_gpu_binding_for_rank(self, on_domain, hwloc_type, local_rank):
         """
         Get a GPU binding for a rank. Local means a numa node close by, remote not.
         """
         if not self.gpu_objects:
-            raise RuntimeError(
-                "Shape type is GPU-aware, but no GPUs were discovered on the system."
-            )
-
+            raise RuntimeError("Shape is GPU-aware, but no GPUs were discovered.")
         if local_rank >= len(self.gpu_objects):
             raise IndexError(
                 f"local_rank {local_rank} is out of range for {len(self.gpu_objects)} GPUs."
             )

         my_gpu_object = self.gpu_objects[local_rank]

-        # Get PCI Bus ID for CUDA_VISIBLE_DEVICES
         pci_bus_id_cmd = f"hwloc-pci-lookup {my_gpu_object}"
-        pci_res = subprocess.run(
+        cuda_devices = subprocess.run(
             pci_bus_id_cmd, shell=True, capture_output=True, text=True, check=True
-        )
-        cuda_devices = pci_res.stdout.strip()
+        ).stdout.strip()

-        # Find the local NUMA node for this GPU
         local_numa_cmd = f"hwloc-calc {my_gpu_object} --ancestor numa -I"
-        numa_res = subprocess.run(
-            local_numa_cmd, shell=True, capture_output=True, text=True, check=True
+        local_numa_id = int(
+            subprocess.run(
+                local_numa_cmd, shell=True, capture_output=True, text=True, check=True
+            ).stdout.strip()
         )
-        local_numa_id = int(numa_res.stdout.strip())

-        if hwloc_type == "gpu-local":
-            binding_string = f"numa:{local_numa_id}"
-
-        # gpu-remote
-        else:
+        target_numa_location = ""
+        if on_domain == "gpu-local":
+            target_numa_location = f"numa:{local_numa_id}"
+        else:  # gpu-remote
             remote_numa_id = (local_numa_id + 1) % self.num_numa
-            binding_string = f"numa:{remote_numa_id}"
-
-        # Return both the binding and the device
+            target_numa_location = f"numa:{remote_numa_id}"
+
+        # If the requested type is just 'numa', we're done.
+        if hwloc_type == "numa":
+            return f"{target_numa_location},{cuda_devices}"
+
+        # Otherwise, find the first object of the requested type WITHIN that NUMA domain.
+        # This is a powerful composition of the two concepts.
+        # E.g., find the first 'core' on the 'gpu-local' NUMA domain.
+        cmd = f"hwloc-calc {target_numa_location} --intersect {hwloc_type} --first"
+        binding_string = subprocess.run(
+            cmd, shell=True, capture_output=True, text=True, check=True
+        ).stdout.strip()
         return f"{binding_string},{cuda_devices}"

     def get_binding_for_rank(self, rank: int, node_id: int, local_rank: int) -> str:
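For illustration, the gpu-local vs gpu-remote NUMA choice and the returned binding format can be sketched with hypothetical values: two NUMA nodes, the rank's GPU attached to NUMA 0, and a made-up PCI bus id standing in for cuda_devices.

# Sketch of how the gpu-local vs gpu-remote NUMA target is chosen in the diff above.
num_numa = 2
local_numa_id = 0                               # NUMA node closest to this rank's GPU
remote_numa_id = (local_numa_id + 1) % num_numa # wrap around to a different NUMA node

cuda_devices = "0000:17:00.0"                   # hypothetical PCI bus id
for on_domain, numa_id in [("gpu-local", local_numa_id), ("gpu-remote", remote_numa_id)]:
    target_numa_location = f"numa:{numa_id}"
    # The method returns "<hwloc location>,<device>", e.g. "numa:0,0000:17:00.0",
    # optionally narrowing the location to the requested hwloc type via hwloc-calc first.
    print(on_domain, f"{target_numa_location},{cuda_devices}")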
@@ -168,8 +186,9 @@ def get_binding_for_rank(self, rank: int, node_id: int, local_rank: int) -> str:
         if hwloc_type is None:
             raise ValueError(f"Matching rule has no 'type' defined: {rule}")

-        if hwloc_type in ["gpu-local", "gpu-remote"]:
-            return self.get_gpu_binding_for_rank(hwloc_type, local_rank)
+        on_domain = rule.get("on")
+        if on_domain in ["gpu-local", "gpu-remote"]:
+            return self.get_gpu_binding_for_rank(on_domain, hwloc_type, local_rank)

         cpu_binding_string = self.get_cpu_binding(hwloc_type, rule, local_rank)

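With this hunk, the caller reads the GPU domain from a separate "on" key on the rule, while "type" stays a plain hwloc object type. A hypothetical rule illustrating that split (the real shape schema may differ):

# Hypothetical rule: bind to a core on the NUMA node local to the rank's GPU.
rule = {"on": "gpu-local", "type": "core"}

on_domain = rule.get("on")      # "gpu-local" or "gpu-remote" routes to the GPU path
hwloc_type = rule.get("type")   # any hwloc object type, e.g. "core" or "numa"
print(on_domain, hwloc_type)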