|
1 | | -import shlex |
2 | 1 | import subprocess |
3 | 2 | import sys |
| 3 | +from itertools import zip_longest |
4 | 4 |
|
5 | 5 |
|
6 | 6 | class Command: |
@@ -29,64 +29,93 @@ def run(self, command, shell: bool = False): |
29 | 29 | class HwlocCalcCommand(Command): |
30 | 30 | name = "hwloc-calc" |
31 | 31 |
|
def _parse_cpuset_to_list(self, cpuset_str: str) -> list[int]:
    """
    Convert a potentially comma-separated hex string into a list of integers.

    Args:
        cpuset_str: Cpuset text such as "0xff" or "0x1,0x0". Surrounding
            whitespace (for example a trailing newline) is ignored.

    Returns:
        One integer per comma-separated chunk, in the order the chunks
        appear in the string. An empty or whitespace-only string yields [0].

    Raises:
        ValueError: If a chunk is not valid hexadecimal text.
    """
    # Strip BEFORE testing for emptiness so that input made only of
    # whitespace is treated as the empty set instead of reaching int("", 16).
    text = cpuset_str.strip() if cpuset_str else ""
    if not text:
        return [0]
    return [int(chunk, 16) for chunk in text.split(",")]
| 39 | + |
def _operate_on_lists(self, list_a: list[int], list_b: list[int], operator: str) -> list[int]:
    """
    Perform a bitwise operation on two lists of cpuset integers.

    Args:
        list_a: Left-hand chunks, most-significant chunk first.
        list_b: Right-hand chunks, most-significant chunk first.
        operator: "+" (OR), "x" (AND), "^" (XOR) or "~" (bits of list_a
            that are not set in list_b).

    Returns:
        A list as long as the longer input holding the chunk-wise result.

    Raises:
        ValueError: If ``operator`` is not one of the supported symbols.
    """
    operations = {
        "+": lambda a, b: a | b,
        "x": lambda a, b: a & b,
        "^": lambda a, b: a ^ b,
        "~": lambda a, b: a & ~b,
    }
    apply = operations.get(operator)
    if apply is None:
        raise ValueError(f"Unsupported operator '{operator}'")

    # hwloc prints multi-chunk bitmaps with the most-significant chunk
    # first, so the LAST elements of both lists describe the same bits.
    # Pad the shorter list with leading zeros to keep them aligned.
    width = max(len(list_a), len(list_b))
    padded_a = [0] * (width - len(list_a)) + list(list_a)
    padded_b = [0] * (width - len(list_b)) + list(list_b)
    return [apply(a, b) for a, b in zip(padded_a, padded_b)]
| 61 | + |
def count(self, hw_type: str, within: str = "machine:0") -> int:
    """
    Returns the total number of a specific hardware object.

    Args:
        hw_type: The type of object to count (e.g., "core", "numa").
        within: Object to restrict the count to (e.g., "numa:0").

    Returns:
        The command output parsed as an integer.

    Raises:
        RuntimeError: If ``self.run`` raises RuntimeError or its output
            cannot be parsed as an integer. The original error is attached
            as the cause.
    """
    command = [self.name, "--number-of", hw_type, within]
    try:
        return int(self.run(command, shell=False))
    except (RuntimeError, ValueError) as e:
        # Chain explicitly so the underlying failure stays visible.
        raise RuntimeError(f"Failed to count number of '{hw_type}': {e}") from e
46 | 72 |
|
def list_cpusets(self, hw_type: str, within: str = "machine:0") -> list[str]:
    """
    Returns a list of cpuset strings for each object of a given type.

    Args:
        hw_type: The type of object to list (e.g., "numa").
        within: Object to restrict the list to.

    Returns:
        One ``self.run`` result per object index reported by the
        ``--intersect`` query, or an empty list when none are reported.

    Raises:
        RuntimeError: If any underlying command raises RuntimeError or
            ValueError. The original error is attached as the cause.
    """
    try:
        # Step 1: ask for the comma-separated indices of matching objects.
        indices_str = self.run([self.name, "--intersect", hw_type, within], shell=False)
        # str.split never returns an empty list, so drop blank tokens
        # instead of testing the list itself for emptiness.
        indices = [tok.strip() for tok in indices_str.split(",") if tok.strip()]
        # Step 2: resolve each index to its own cpuset.
        return [self.run([self.name, f"{hw_type}:{i}"], shell=False) for i in indices]
    except (RuntimeError, ValueError) as e:
        raise RuntimeError(f"Failed to list cpusets for '{hw_type}': {e}") from e
69 | 86 |
|
def get_cpuset(self, location: str) -> str:
    """
    Gets the cpuset for one or more space/operator-separated location strings.

    Args:
        location: One or more hwloc-calc arguments separated by whitespace
            (e.g. "core:0" or "core:0 core:1"). Shell-style quoting is
            honoured when splitting, but no shell is involved, so
            variables, globs and other metacharacters are passed literally.

    Returns:
        Whatever ``self.run`` returns for the assembled command.

    Raises:
        ValueError: If ``location`` contains unbalanced quotes.
    """
    # Local import: only this method needs shlex.
    import shlex

    # Build an argv list instead of an interpolated shell string so that
    # characters such as ';', '$' or '`' in a location cannot run commands.
    return self.run([self.name, *shlex.split(location)], shell=False)
75 | 92 |
|
def get_object_in_set(self, cpuset: str, obj_type: str, index: int) -> str:
    """
    Gets the Nth object of a type within a given cpuset.

    Args:
        cpuset: Cpuset string, passed to the command as a single argument.
        obj_type: Object type to look for (e.g. "core").
        index: Position within the list of matching object indices.

    Returns:
        A location string of the form "<obj_type>:<object index>".

    Raises:
        ValueError: If there is no object at ``index`` (including when the
            command reports no objects at all).
    """
    # An argv list keeps cpuset as one argument without hand-written shell
    # quoting, which would break if cpuset contained a quote character.
    listing = self.run([self.name, cpuset, "--intersect", obj_type], shell=False)
    # Drop blank tokens so empty output means "no objects", not [""].
    available_indices = [tok.strip() for tok in listing.split(",") if tok.strip()]
    try:
        return f"{obj_type}:{available_indices[index]}"
    except IndexError:
        raise ValueError(f"Cannot find the {index}-th '{obj_type}' in cpuset {cpuset}.") from None
89 | 105 |
|
def union_of_locations(self, locations: list[str]) -> str:
    """
    Calculates the union of a list of hwloc location strings using Python logic.

    Args:
        locations: Location strings, each resolved via ``self.get_cpuset``.

    Returns:
        A single, SPACE-separated string of hex chunks; "0x0" when
        ``locations`` is empty.
    """
    # NOTE(review): chunks are joined with spaces while the parser splits
    # on commas - confirm consumers expect the space-separated form.
    union_mask_list = [0]
    for loc in locations:
        loc_cpuset_list = self._parse_cpuset_to_list(self.get_cpuset(loc))
        # "+" is the OR operator of _operate_on_lists.
        union_mask_list = self._operate_on_lists(union_mask_list, loc_cpuset_list, "+")
    return " ".join(hex(chunk) for chunk in union_mask_list)
| 118 | + |
90 | 119 |
|
91 | 120 | class NvidiaSmiCommand(Command): |
92 | 121 | name = "nvidia-smi" |
|
0 commit comments