1+ import re
12import subprocess
23import sys
34
def evaluate_formula(formula_template: str, local_rank: int) -> str:
    """Expand a binding formula template into its final string value.

    Two substitution passes are applied:

    1. Each ``{{command}}`` placeholder is replaced by the stripped stdout
       of running ``command`` through the shell.
    2. The literal token ``$local_rank`` is replaced by ``local_rank``, and
       the result is passed through ``echo`` so the shell performs any
       remaining expansion (e.g. arithmetic such as ``$((2 * $local_rank))``).

    This assumes running on the rank where the binding is asked for.

    Args:
        formula_template: Formula text, possibly containing ``{{command}}``
            placeholders and the ``$local_rank`` token.
        local_rank: Node-local rank substituted for ``$local_rank``.

    Returns:
        The shell-expanded formula (stdout of the final ``echo``, stripped).
        NOTE: the previous ``-> int`` annotation was stale — the body
        returns the string as-is; the annotation now matches.

    Raises:
        RuntimeError: If a ``{{command}}`` sub-command exits non-zero.
        subprocess.CalledProcessError: If the final ``echo`` itself fails.
    """
    processed_formula = str(formula_template)
    # Collect the {{...}} placeholders; de-duplicate so each distinct
    # sub-command runs only once even if it appears several times.
    substitutions = re.findall(r"\{\{([^}]+)\}\}", processed_formula)
    for command_to_run in set(substitutions):
        # NOTE(review): shell=True on configuration-provided text — formula
        # templates must come from a trusted source.
        try:
            result = subprocess.run(
                command_to_run, shell=True, capture_output=True, text=True, check=True
            )
        except subprocess.CalledProcessError as e:
            # Chain the original error so the failing command's exit status
            # and output stay attached to the traceback.
            raise RuntimeError(
                f"Error executing sub-command '{command_to_run}': {e}"
            ) from e
        placeholder = f"{{{{{command_to_run}}}}}"
        processed_formula = processed_formula.replace(placeholder, result.stdout.strip())

    # Substitute local_rank and let the shell evaluate the final expression.
    final_expression = processed_formula.replace("$local_rank", str(local_rank))
    command = f'echo "{final_expression}"'
    result = subprocess.run(command, shell=True, capture_output=True, text=True, check=True)
    return result.stdout.strip()
8296
8397 def find_matching_rule (self , rank : int , node_id : int ) -> dict :
8498 """
@@ -105,45 +119,49 @@ def find_matching_rule(self, rank: int, node_id: int) -> dict:
105119 return item ["default" ]
106120 return None
107121
def get_gpu_binding_for_rank(self, on_domain, hwloc_type, local_rank):
    """Build a "<binding>,<pci-bus-id>" string for one GPU-aware rank.

    Get a GPU binding for a rank. Local means a NUMA node close by,
    remote not.

    Args:
        on_domain: "gpu-local" to bind on the NUMA node closest to this
            rank's GPU; otherwise (expected: "gpu-remote") the next NUMA
            node round-robin.
        hwloc_type: hwloc object type for the binding. "numa" binds the
            whole NUMA domain; any other type (e.g. "core") binds the first
            object of that type within the chosen NUMA domain.
        local_rank: Node-local rank; indexes into ``self.gpu_objects``.

    Returns:
        str: the hwloc binding location, a comma, and the GPU's PCI bus id
        (the latter intended for CUDA_VISIBLE_DEVICES).

    Raises:
        RuntimeError: No GPUs were discovered on the system.
        IndexError: ``local_rank`` exceeds the number of discovered GPUs.
        subprocess.CalledProcessError: An hwloc command failed.
    """
    if not self.gpu_objects:
        raise RuntimeError("Shape is GPU-aware, but no GPUs were discovered.")
    if local_rank >= len(self.gpu_objects):
        raise IndexError(
            f"local_rank {local_rank} is out of range for {len(self.gpu_objects)} GPUs."
        )

    my_gpu_object = self.gpu_objects[local_rank]

    # PCI bus id of this rank's GPU — used as the CUDA device selector.
    pci_bus_id_cmd = f"hwloc-pci-lookup {my_gpu_object}"
    cuda_devices = subprocess.run(
        pci_bus_id_cmd, shell=True, capture_output=True, text=True, check=True
    ).stdout.strip()

    # NUMA node physically closest to the GPU.
    local_numa_cmd = f"hwloc-calc {my_gpu_object} --ancestor numa -I"
    local_numa_id = int(
        subprocess.run(
            local_numa_cmd, shell=True, capture_output=True, text=True, check=True
        ).stdout.strip()
    )

    # Pick the NUMA domain (the previous dead "" initializer is dropped).
    if on_domain == "gpu-local":
        target_numa_id = local_numa_id
    else:  # gpu-remote: round-robin to the next NUMA node.
        target_numa_id = (local_numa_id + 1) % self.num_numa
    target_numa_location = f"numa:{target_numa_id}"

    # If the requested type is just 'numa', the domain itself is the binding.
    if hwloc_type == "numa":
        return f"{target_numa_location},{cuda_devices}"

    # Otherwise, find the first object of the requested type WITHIN that NUMA
    # domain — e.g. the first 'core' on the 'gpu-local' NUMA domain. This is
    # a powerful composition of the two concepts.
    cmd = f"hwloc-calc {target_numa_location} --intersect {hwloc_type} --first"
    binding_string = subprocess.run(
        cmd, shell=True, capture_output=True, text=True, check=True
    ).stdout.strip()
    return f"{binding_string},{cuda_devices}"
148166
149167 def get_binding_for_rank (self , rank : int , node_id : int , local_rank : int ) -> str :
@@ -168,8 +186,9 @@ def get_binding_for_rank(self, rank: int, node_id: int, local_rank: int) -> str:
168186 if hwloc_type is None :
169187 raise ValueError (f"Matching rule has no 'type' defined: { rule } " )
170188
171- if hwloc_type in ["gpu-local" , "gpu-remote" ]:
172- return self .get_gpu_binding_for_rank (hwloc_type , local_rank )
189+ on_domain = rule .get ("on" )
190+ if on_domain in ["gpu-local" , "gpu-remote" ]:
191+ return self .get_gpu_binding_for_rank (on_domain , hwloc_type , local_rank )
173192
174193 cpu_binding_string = self .get_cpu_binding (hwloc_type , rule , local_rank )
175194
0 commit comments