|
1 | 1 | import os |
2 | 2 | import re |
3 | | -import hashlib |
4 | 3 | import subprocess |
5 | 4 |
|
6 | 5 | from abc import ABCMeta, abstractmethod, abstractclassmethod |
|
9 | 8 | from types import ModuleType |
10 | 9 |
|
11 | 10 |
|
class AttrsDescriptor:
    """
    Compile-time properties attached to specific function parameters.

    Different backends can add more properties to the common ones by overriding
    `_add_backend_properties`. The class keeps three fields:

    `arg_properties`: maps a property name to the indices of the parameters that
        have that property, e.g.
        {
            "prop0": [0, 2, 3],
            "prop1": [0, 4, 5],
        }
        The common properties registered by this class are:
            - "tt.divisibility": the parameter is divisible by 16
            - "tt.equal_to": the parameter is the integer constant 1

    `property_values`: maps a property name to the value of that property, e.g.
        {
            "prop0": val0,
            "prop1": val1,
        }

    `constant_properties`: the set of property names that mark a parameter as a
        compile-time constant.
    """
    __slots__ = ('divisibility_16', 'equal_to_1', 'arg_properties', 'property_values', 'constant_properties')

    def __init__(self, params=None, values=None):
        """
        Initialize the compile-time properties.

        When both `params` (the function parameters) and `values` (the matching
        argument values) are given, the properties are deduced from the values and
        the indices of the matching parameters are stored in `arg_properties`.
        When either is missing only the property values are registered; use
        `from_dict` or `from_hints` to fill in `arg_properties` instead.
        """
        # Default initialization
        self.arg_properties = {}
        self.property_values = {}
        self.constant_properties = set()

        self._add_common_properties(params, values)
        self._add_backend_properties(params, values)
        self._init_slots()

    def _add_common_properties(self, params, values):
        """ Add the compile-time properties shared by every backend """
        self.property_values["tt.divisibility"] = 16
        self.property_values["tt.equal_to"] = 1
        self.constant_properties.add("tt.equal_to")

        if (params is None) or (values is None):
            return

        # Compile properties deduction
        assert (len(params) == len(values))

        # Divisibility property: skipped for parameters that opted out of
        # specialization altogether or of alignment specialization only.
        self.arg_properties["tt.divisibility"] = [
            param.num
            for param, arg in zip(params, values)
            if AttrsDescriptor.is_divisible_by_16(arg) and not param.do_not_specialize
            and not param.do_not_specialize_on_alignment
        ]

        # Equal to 1 property
        self.arg_properties["tt.equal_to"] = [
            param.num
            for param, arg in zip(params, values)
            if AttrsDescriptor.is_equal_to_1(arg) and not param.do_not_specialize
        ]

    def _add_backend_properties(self, params=None, values=None):
        """ Hook for subclasses to register their own compile-time properties """
        pass

    def _init_slots(self):
        """
        Mirror every entry of `arg_properties` into a slot attribute.

        The attribute name is the property name without its "tt." prefix followed
        by "_<property value>", e.g. "tt.divisibility" -> `divisibility_16`.
        """
        for name, val in self.arg_properties.items():
            setattr(self, name.removeprefix('tt.') + '_' + str(self.property_values[name]), val)

    def get_fn_attrs(self) -> dict:
        """
        Get the function attributes as a dictionary keyed by parameter index.

        The returned dictionary looks like:
        {
            arg0: [(prop_name00, val00), (prop_name01, val01), ...],
            arg1: [(prop_name10, val10), (prop_name11, val11), ...],
        }
        """
        attrs = {}
        for prop_name, arg_set in self.arg_properties.items():
            prop_val = self.property_values[prop_name]
            for arg in arg_set:
                # Append in place instead of rebuilding the list for every property.
                attrs.setdefault(arg, []).append((prop_name, prop_val))
        return attrs

    def get_constants(self) -> dict:
        """ Return a mapping of constant parameters to their values """
        constants = {}
        for prop_name in self.constant_properties:
            for p in self.arg_properties.get(prop_name, []):
                constants[p] = self.property_values[prop_name]
        return constants

    def filter_out_constants(self):
        """ Return a deep copy of this object without the properties marked as constants """
        import copy
        c = copy.deepcopy(self)
        for prop_name in c.constant_properties:
            c.arg_properties.pop(prop_name, None)
            c.property_values.pop(prop_name, None)
        # Keep the field a set (it used to be reset to an empty dict, which broke
        # any later `.add()` on the filtered copy).
        c.constant_properties = set()
        return c

    def hash(self):
        """
        Return a SHA-256 hex digest identifying this descriptor.

        Property names are part of the key and entries are ordered by name, so two
        descriptors that assign the same index lists to different properties do
        not collide, and dictionary insertion order does not affect the digest.
        """
        import hashlib
        values = [
            sorted((name, list(ids)) for name, ids in self.arg_properties.items()),
            sorted(self.property_values.items()),
            sorted(self.constant_properties),
        ]
        key = str(values)
        return hashlib.sha256(key.encode("utf-8")).hexdigest()

    def to_dict(self):
        """ Return the `arg_properties` mapping (the live dictionary, not a copy) """
        return self.arg_properties

    @classmethod
    def from_dict(cls, data):
        """
        Create a descriptor from a `{property name: parameter indices}` mapping.

        Built through `cls` so that subclasses get an instance of their own type
        with their backend properties registered. Each index sequence is copied
        into a fresh list so the descriptor does not share state with `data`.
        """
        attrsDescriptor = cls()
        for prop_name, param_ids in data.items():
            attrsDescriptor.arg_properties[prop_name] = list(param_ids)
        attrsDescriptor._init_slots()
        return attrsDescriptor

    @classmethod
    def from_hints(cls, hints: "dict[int, int] | list[tuple[int, int]]"):
        """
        Create the class from a set of hints that are passed in.

        Instead of deducing the properties from a list of parameters and values,
        the user can pass `hints` either as a `{param_index: val}` mapping or as a
        list of `(param_index, val)` pairs. If `val` matches the value of one of
        the properties (e.g., `property_values[prop0]`), `param_index` is inserted
        into the matching list (e.g., `arg_properties[prop0]`).
        """
        # Accept both the mapping form and the pair-list form named in the annotation.
        pairs = list(hints.items()) if hasattr(hints, "items") else list(hints)
        attrsDescriptor = cls()
        for prop_name, prop_val in attrsDescriptor.property_values.items():
            attrsDescriptor.arg_properties[prop_name] = [i for i, h in pairs if h == prop_val]
        attrsDescriptor._init_slots()
        return attrsDescriptor

    @staticmethod
    def is_divisible_by_16(x):
        """
        Return if the argument is a multiple of 16.

        Objects exposing `data_ptr()` are checked on that address, integers on
        their value, and `None` counts as divisible. Anything else is not.
        """
        if hasattr(x, "data_ptr"):
            return x.data_ptr() % 16 == 0
        if isinstance(x, int):
            return x % 16 == 0
        return x is None

    @staticmethod
    def is_equal_to_1(x):
        """ Return if the argument is the integer constant 1 (booleans are excluded) """
        return isinstance(x, int) and not isinstance(x, bool) and x == 1

    @staticmethod
    def get_property_key(val, align):
        """
        Return the one-character key for `val`.

        "D" when `align` is set and `val` is divisible by 16, "1" when `val` is
        the integer constant 1, and "N" otherwise.
        """
        if align and AttrsDescriptor.is_divisible_by_16(val):
            return "D"
        if AttrsDescriptor.is_equal_to_1(val):
            return "1"
        return "N"
188 | | - |
189 | | - |
190 | 11 | @dataclass(frozen=True) |
191 | 12 | class GPUTarget(object): |
192 | 13 | # Target backend, e.g., cuda, hip |
@@ -258,20 +79,6 @@ def load_dialects(self, context): |
@abstractmethod
def get_module_map(self) -> Dict[str, ModuleType]:
    """
    Return a map of interface modules to their device-specific implementations.

    Abstract: this base version always raises NotImplementedError.
    """
    raise NotImplementedError()
264 | | - |
def get_attrs_descriptor(self, params, args):
    """
    Build the attribute descriptor for a set of parameters and arguments.

    The descriptor stores compile-time properties that can improve code
    generation; different backends might benefit from different properties.
    """
    descriptor = AttrsDescriptor(params, args)
    return descriptor
272 | | - |
def compute_spec_key(self, arg, align):
    """
    Return the ascii key for a given argument with a given set of properties.

    Delegates to `AttrsDescriptor.get_property_key`.
    """
    spec_key = AttrsDescriptor.get_property_key(arg, align)
    return spec_key
0 commit comments