|
1 | 1 | import os |
2 | 2 | import re |
| 3 | +import hashlib |
3 | 4 | import subprocess |
4 | 5 |
|
5 | 6 | from abc import ABCMeta, abstractmethod, abstractclassmethod |
|
8 | 9 | from types import ModuleType |
9 | 10 |
|
10 | 11 |
|
class AttrsDescriptor:
    """
    Hold the compile-time properties deduced for specific function parameters.

    Different backends can add more properties to the common ones. The class
    contains three fields:

    `arg_properties`: a dictionary mapping each compile-time property name to the
    indices of the parameters that satisfy it, e.g.:
    {
        "prop0": (0, 2, 3)
        "prop1": (0, 4, 5)
    }
    Different backends might need different properties on those parameters to enable
    specific optimizations. The common compile-time properties contained in this class
    are:
      - "tt.divisibility", i.e., is the given parameter divisible by 16
      - "tt.equal_to", i.e., is the given parameter an integer constant 1

    `property_values`: a dictionary containing the value of the different compile-time
    properties, like:
    {
        "prop0": val0
        "prop1": val1
    }

    `constant_properties`: a set containing the properties that can be used to determine
    if a parameter is constant

    Every entry of `arg_properties` is additionally mirrored as an attribute named
    `<property name without the "tt." prefix>_<property value>` (for the common
    properties: `divisibility_16` and `equal_to_1`). These attributes are only set
    for properties that are present in `arg_properties`.
    """
    __slots__ = ('divisibility_16', 'equal_to_1', 'arg_properties', 'property_values', 'constant_properties')

    def __init__(self, params=None, values=None):
        """
        Initialize the compile-time properties.

        We can initialize the AttrsDescriptor class by passing the list of params
        of the function and their `values`. The function will try to apply the properties
        to the values and save the parameters in the `arg_properties` dictionary. If we
        don't pass either the `params` or the `values` we should initialize the class via
        an alternative method (see `from_dict` or `from_hints`).
        """
        # Default initialization
        self.arg_properties = {}
        self.property_values = {}
        self.constant_properties = set()

        self._add_common_properties(params, values)
        self._add_backend_properties(params, values)
        self._init_slots()

    def _add_common_properties(self, params, values):
        """
        Add the common compile-time properties.

        The property values are always registered; the per-parameter deduction only
        runs when both `params` and `values` are provided.
        """
        self.property_values["tt.divisibility"] = 16
        self.property_values["tt.equal_to"] = 1
        self.constant_properties.add("tt.equal_to")

        if (params is None) or (values is None):
            return

        # Compile properties deduction
        assert (len(params) == len(values))

        # Divisibility property
        self.arg_properties["tt.divisibility"] = [
            param.num for param, arg in zip(params, values) if AttrsDescriptor.is_divisible_by_16(arg)
            and not param.do_not_specialize and not param.do_not_specialize_on_alignment
        ]

        # Equal to 1 property
        self.arg_properties["tt.equal_to"] = [
            param.num
            for param, arg in zip(params, values)
            if AttrsDescriptor.is_equal_to_1(arg) and not param.do_not_specialize
        ]

    def _add_backend_properties(self, params=None, values=None):
        """ This method is for different subclasses to implement their own compile-time properties """
        pass

    def _init_slots(self):
        """
        Mirror every entry of `arg_properties` as an attribute of this object.

        The attribute name is the property name without its "tt." prefix, followed by
        "_" and the property value (e.g. "tt.divisibility" -> `divisibility_16`).
        Raises KeyError if a property has no entry in `property_values`.
        """
        for name, val in self.arg_properties.items():
            setattr(self, name.removeprefix('tt.') + '_' + str(self.property_values[name]), val)

    def get_fn_attrs(self) -> Dict:
        """
        Get the function attributes as a dictionary.

        The returned dictionary maps each parameter index to the list of
        (property name, property value) pairs that apply to it:
        {
            arg0 : [(prop_name00, val00), (prop_name01, val01), ...],
            arg1 : [(prop_name10, val10), (prop_name11, val11), ...],
        }
        """
        attrs = {}
        for prop_name, arg_set in self.arg_properties.items():
            prop_val = self.property_values[prop_name]
            for arg in arg_set:
                attrs.setdefault(arg, []).append((prop_name, prop_val))
        return attrs

    def get_constants(self) -> Dict:
        """ Return a mapping of constant parameters to their values """
        constants = {}
        for prop_name in self.constant_properties:
            for p in self.arg_properties.get(prop_name, []):
                constants[p] = self.property_values[prop_name]
        return constants

    def filter_out_constants(self):
        """
        Return a deep copy of this object without the properties marked as constants.

        The original object is left untouched. Only the three dictionaries/sets are
        filtered; attributes previously mirrored by `_init_slots` are copied as they are.
        """
        import copy
        c = copy.deepcopy(self)
        for prop_name in c.constant_properties:
            c.arg_properties.pop(prop_name, None)
            c.property_values.pop(prop_name, None)
        # Keep the field a set (it was previously reset to an empty dict).
        c.constant_properties = set()
        return c

    def hash(self):
        """
        Return a SHA-256 hex digest identifying the content of this descriptor.

        The digest is computed over (property name, value) pairs sorted by name, so
        that two descriptors which assign the same parameter indices to *different*
        properties do not collide.
        """
        values = [sorted((name, sorted(ids)) for name, ids in self.arg_properties.items())]
        values += [sorted(self.property_values.items())]
        values += [sorted(self.constant_properties)]
        key = str(values)
        return hashlib.sha256(key.encode("utf-8")).hexdigest()

    def to_dict(self):
        """ Return the `arg_properties` dictionary (the same object, not a copy) """
        return self.arg_properties

    @staticmethod
    def from_dict(data):
        """
        Create the class from a dictionary shaped like `arg_properties`,
        i.e. a map from property names to parameter indices.
        """
        attrs_descriptor = AttrsDescriptor()
        for prop_name, param_ids in data.items():
            attrs_descriptor.arg_properties[prop_name] = param_ids
        attrs_descriptor._init_slots()
        return attrs_descriptor

    @staticmethod
    def from_hints(hints: dict[int, int]):
        """
        Create the class from a set of hints that are passed in.

        Instead of deducing the properties from a list of parameters and values,
        the user can pass in `hints` as a mapping `{param_index: val}` (an iterable of
        `(param_index, val)` pairs is accepted as well) and if `val` matches one of the
        values of the properties (e.g., `property_values[prop0]`), then we insert
        `param_index` into the correct list (e.g., in `arg_properties[prop0]`).
        """
        attrs_descriptor = AttrsDescriptor()
        # Materialize once: the pairs are scanned once per property.
        pairs = list(hints.items()) if hasattr(hints, "items") else list(hints)
        for prop_name, prop_val in attrs_descriptor.property_values.items():
            attrs_descriptor.arg_properties[prop_name] = [i for i, h in pairs if h == prop_val]
        attrs_descriptor._init_slots()
        return attrs_descriptor

    @staticmethod
    def is_divisible_by_16(x):
        """
        Return if the argument is a multiple of 16.

        Objects exposing `data_ptr()` are checked on the returned address, integers on
        their value, and `None` is treated as divisible. Anything else is not.
        """
        if hasattr(x, "data_ptr"):
            return x.data_ptr() % 16 == 0
        if isinstance(x, int):
            return x % 16 == 0
        return x is None

    @staticmethod
    def is_equal_to_1(x):
        """ Return if the argument is a constant 1 (booleans are excluded) """
        return isinstance(x, int) and not isinstance(x, bool) and x == 1

    @staticmethod
    def get_property_key(val, align):
        """
        Return the one-character key describing `val`: "D" if `align` is set and the
        value is divisible by 16, "1" if it is the integer constant 1, "N" otherwise.
        """
        if align and AttrsDescriptor.is_divisible_by_16(val):
            return "D"
        if AttrsDescriptor.is_equal_to_1(val):
            return "1"
        return "N"

    def __repr__(self):
        return f"AttrsDescriptor.from_dict({self.arg_properties})"
| 192 | + |
11 | 193 | @dataclass(frozen=True) |
12 | 194 | class GPUTarget(object): |
13 | 195 | # Target backend, e.g., cuda, hip |
@@ -79,6 +261,20 @@ def load_dialects(self, context): |
@abstractmethod
def get_module_map(self) -> Dict[str, ModuleType]:
    """
    Return a map of interface modules to their device-specific implementations.

    This base definition is abstract and always raises NotImplementedError.
    """
    raise NotImplementedError()
| 267 | + |
def get_attrs_descriptor(self, params, args):
    """
    Build the attribute descriptor for the given parameters and arguments.

    The descriptor stores a set of compile-time properties that can improve code
    generation. This implementation returns a plain `AttrsDescriptor` constructed
    from `params` and `args`; different backends might benefit from different
    properties.
    """
    descriptor = AttrsDescriptor(params, args)
    return descriptor
| 275 | + |
def compute_spec_key(self, arg, align):
    """
    Return the ascii key for a given argument with a given set of properties.

    Delegates to `AttrsDescriptor.get_property_key(arg, align)`.
    """
    spec_key = AttrsDescriptor.get_property_key(arg, align)
    return spec_key
0 commit comments