|
1 | 1 | import os |
2 | 2 | import re |
3 | | -import hashlib |
4 | 3 | import subprocess |
5 | 4 | import sysconfig |
6 | 5 | from abc import ABCMeta, abstractmethod |
7 | 6 | from dataclasses import dataclass |
8 | | -from typing import Dict, List, Tuple, Union |
| 7 | +from typing import Dict, Union |
9 | 8 | from types import ModuleType |
10 | | -from .._utils import find_paths_if |
11 | | - |
12 | | -# Table that associates strings to AttrsDescriptor (sub)classes. |
13 | | -# In this way we can dynamically select the correct class |
14 | | -# constructor |
15 | | -_descriptor_table = {} |
16 | | - |
17 | | - |
18 | | -def register_descriptor(cls): |
19 | | - """ |
20 | | - Register a descriptor into the descriptor table |
21 | | - """ |
22 | | - _descriptor_table[cls.__name__] = cls |
23 | | - return cls |
24 | | - |
25 | | - |
26 | | -@register_descriptor |
27 | | -class AttrsDescriptor: |
28 | | - """ |
29 | | - This class handles compile-time properties for specific function parameters. |
30 | | -
|
31 | | - Different backends can add more properties to the common ones. The class |
32 | | - contains two fields: |
33 | | -
|
34 | | - `arg_properties`: a dictionary containing the different compile-time properties for different |
35 | | - parameters. I.e., the dictionary is a map from property names to parameter indices |
36 | | - { |
37 | | - "prop0": (0, 2, 3) |
38 | | - "prop1": (0, 4, 5) |
39 | | - } |
40 | | - Different backends might need different properties on those paraemters to enable |
41 | | - specific optimizations. The common compile time properties contained in this class |
42 | | - are : |
43 | | - - "tt.divisibility", i.e., is the given parameter divisible by 16 |
44 | | - - "tt.equal_to_1", i.e., is the given parameter an integer constant 1 |
45 | | -
|
46 | | - `property_values`: a dictionary containing the value of the different compile-time properties, like: |
47 | | - { |
48 | | - "prop0": val0 |
49 | | - "prop1": val1 |
50 | | - } |
51 | | -
|
52 | | - `constant_properties`: a set containing the properties that can be used to determine if a parameter is constant |
53 | | -
|
54 | | - """ |
55 | | - __slots__ = ('divisibility_16', 'equal_to_1', 'equal_to_none', 'arg_properties', 'property_values', |
56 | | - 'constant_properties') |
57 | | - |
58 | | - def __init__(self, params=None, values=None): |
59 | | - """ |
60 | | - Initialize the compile-time properties |
61 | | -
|
62 | | - We can initialize the AttrsDescriptor class by passing the list of params |
63 | | - of the function and their `values`. The function will try to apply the properties |
64 | | - to the values and save the parameters in the `arg_properties` list. If we don't pass |
65 | | - either the `params` or the `values` we should initialize the class via an alternative method |
66 | | - (see `from_dict` or `from_hints`) |
67 | | - """ |
68 | | - # Default initialization |
69 | | - self.arg_properties = {} |
70 | | - self.property_values = {} |
71 | | - self.equal_to_none = {} |
72 | | - self.constant_properties = set() |
73 | | - |
74 | | - self._add_common_properties(params, values) |
75 | | - self._add_backend_properties(params, values) |
76 | | - self._init_slots() |
77 | | - |
78 | | - def _add_common_properties(self, params, values): |
79 | | - """ Add common compile-time properties """ |
80 | | - self.property_values["tt.divisibility"] = 16 |
81 | | - self.property_values["tt.equal_to"] = 1 |
82 | | - self.constant_properties.add("tt.equal_to") |
83 | | - |
84 | | - if (params is None) or (values is None): |
85 | | - return |
86 | | - |
87 | | - # Compile properties deduction |
88 | | - assert (len(params) == len(values)) |
89 | | - |
90 | | - # Divisibility property |
91 | | - divisibility_16 = [] |
92 | | - for param, arg in zip(params, values): |
93 | | - if param.do_not_specialize or \ |
94 | | - param.do_not_specialize_on_alignment: |
95 | | - continue |
96 | | - paths = find_paths_if(arg, lambda path, val: AttrsDescriptor.is_divisible_by_16(val)) |
97 | | - divisibility_16 += [(param.num, ) + x for x in paths] |
98 | | - self.arg_properties["tt.divisibility"] = divisibility_16 |
99 | | - |
100 | | - # Equal to 1 property |
101 | | - equal_to_1 = [] |
102 | | - for param, arg in zip(params, values): |
103 | | - if param.do_not_specialize: |
104 | | - continue |
105 | | - paths = find_paths_if(arg, lambda path, val: AttrsDescriptor.is_equal_to_1(val)) |
106 | | - equal_to_1 += [(param.num, ) + x for x in paths] |
107 | | - self.arg_properties["tt.equal_to"] = equal_to_1 |
108 | | - |
109 | | - # Equal to None property |
110 | | - equal_to_none = [] |
111 | | - for param, arg in zip(params, values): |
112 | | - paths = find_paths_if(arg, lambda path, val: val is None) |
113 | | - equal_to_none += [(param.num, ) + x for x in paths] |
114 | | - self.equal_to_none = equal_to_none |
115 | | - |
116 | | - def _add_backend_properties(self, params=None, values=None): |
117 | | - """ This method is for different subclasses to implement their own compile-time properties """ |
118 | | - pass |
119 | | - |
120 | | - def _init_slots(self): |
121 | | - """ Initialize the slots of this class """ |
122 | | - for name, val in self.arg_properties.items(): |
123 | | - setattr(self, name.removeprefix('tt.') + '_' + str(self.property_values[name]), val) |
124 | | - |
125 | | - def get_fn_attrs(self) -> Dict: |
126 | | - """ |
127 | | - Get the function attributes as a dictionary. |
128 | | -
|
129 | | - The returned dictionary will look like : |
130 | | - { |
131 | | - "arg0" : [(prop_name00, val00), (prop_name01, val01), ...)]} |
132 | | - "arg1" : [(prop_name10, val10), (prop_name11, val11), ...)]} |
133 | | - } |
134 | | - """ |
135 | | - attrs = {} |
136 | | - for prop_name, arg_set in self.arg_properties.items(): |
137 | | - prop_val = self.property_values[prop_name] |
138 | | - for arg in arg_set: |
139 | | - attrs[arg] = attrs.get(arg, []) + [(prop_name, prop_val)] |
140 | | - return attrs |
141 | | - |
142 | | - def get_constants(self) -> Dict: |
143 | | - """ Return a mapping of constant parameters to their values """ |
144 | | - constants = {} |
145 | | - for prop_name in self.constant_properties: |
146 | | - for p in self.arg_properties.get(prop_name, []): |
147 | | - constants[p] = self.property_values[prop_name] |
148 | | - for v in self.equal_to_none: |
149 | | - constants[v] = None |
150 | | - return constants |
151 | | - |
152 | | - def filter_out_constants(self): |
153 | | - """ Return the same object, without properties marked as constants""" |
154 | | - import copy |
155 | | - c = copy.deepcopy(self) |
156 | | - for prop_name in c.constant_properties: |
157 | | - c.arg_properties.pop(prop_name, None) |
158 | | - c.property_values.pop(prop_name, None) |
159 | | - c.constant_properties = {} |
160 | | - return c |
161 | | - |
162 | | - def hash(self): |
163 | | - values = [sorted(self.arg_properties.values())] |
164 | | - values += [sorted(self.property_values.values())] |
165 | | - values += [sorted(self.constant_properties)] |
166 | | - key = str(values) |
167 | | - return hashlib.sha256(key.encode("utf-8")).hexdigest() |
168 | | - |
169 | | - def to_dict(self): |
170 | | - """ |
171 | | - Store the fields of this class in a serializable dictionary |
172 | | - """ |
173 | | - # We need to only store the `arg_properties` field. To initialize the |
174 | | - # other fields we relay on the class type. We store it as a string in |
175 | | - # the dictionary so that we can use it to invoke the appropriate |
176 | | - # (sub)class constructor in the `from_dict` method. |
177 | | - return {"arg_properties": self.arg_properties, "cls": type(self).__name__} |
178 | | - |
179 | | - @staticmethod |
180 | | - def from_dict(data): |
181 | | - """ |
182 | | - Create the object from a serializable dictionary |
183 | | - """ |
184 | | - attrs_descriptor = _descriptor_table[data["cls"]]() |
185 | | - for prop_name, param_ids in data["arg_properties"].items(): |
186 | | - attrs_descriptor.arg_properties[prop_name] = list(map(tuple, param_ids)) |
187 | | - attrs_descriptor._init_slots() |
188 | | - return attrs_descriptor |
189 | | - |
190 | | - @classmethod |
191 | | - def from_hints(cls, hints: List[Tuple[int, int]]): |
192 | | - """ |
193 | | - Create the class from a set of hints that are passed in. |
194 | | -
|
195 | | - Instead of deducing the properties from a list of paramaters and values, |
196 | | - the user can pass in a list of `hints=[(param_index, val)]` and if `val` |
197 | | - matches one of the values of the properties (e.g., `prop_val[prop0]`), |
198 | | - then we insert `param_index` into the correct list (e.g., in |
199 | | - `arg_properties[prop0]`) |
200 | | - """ |
201 | | - attrs_descriptor = cls() |
202 | | - for prop_name, prop_val in attrs_descriptor.property_values.items(): |
203 | | - attrs_descriptor.arg_properties[prop_name] = [i for i, h in hints.items() if h == prop_val] |
204 | | - attrs_descriptor._init_slots() |
205 | | - return attrs_descriptor |
206 | | - |
207 | | - @staticmethod |
208 | | - def is_divisible_by_16(x): |
209 | | - """ Return if the argument is a multiple of 16""" |
210 | | - if hasattr(x, "data_ptr"): |
211 | | - return x.data_ptr() % 16 == 0 |
212 | | - elif isinstance(x, int): |
213 | | - return x % 16 == 0 |
214 | | - if x is None: |
215 | | - return True |
216 | | - return False |
217 | | - |
218 | | - @staticmethod |
219 | | - def is_equal_to_1(x): |
220 | | - """ Return if the argument is a constant 1""" |
221 | | - return True if isinstance(x, int) and not isinstance(x, bool) and x == 1 else False |
222 | | - |
223 | | - @staticmethod |
224 | | - def get_property_key(val, align): |
225 | | - if align and AttrsDescriptor.is_divisible_by_16(val): |
226 | | - return "D" |
227 | | - if AttrsDescriptor.is_equal_to_1(val): |
228 | | - return "1" |
229 | | - return "N" |
230 | | - |
231 | | - def __repr__(self): |
232 | | - return f"AttrsDescriptor.from_dict({self.to_dict()!r})" |
233 | 9 |
|
234 | 10 |
|
235 | 11 | @dataclass(frozen=True) |
@@ -308,16 +84,21 @@ def get_module_map(self) -> Dict[str, ModuleType]: |
308 | 84 | """ |
309 | 85 | raise NotImplementedError |
310 | 86 |
|
311 | | - def get_attrs_descriptor(self, params, args): |
312 | | - """ |
313 | | - Return an attribute descriptor: given a set of parameters and arguments |
314 | | - the descriptor stores a set of compile time properties that can improve code |
315 | | - generation. Different backends might benefit from different properties |
316 | | - """ |
317 | | - return AttrsDescriptor(params, args) |
| 87 | + @staticmethod |
| 88 | + def parse_attr(desc): |
| 89 | + assert isinstance(desc, str) |
| 90 | + ret = [] |
| 91 | + if "D" in desc: |
| 92 | + ret += [["tt.divisibility", 16]] |
| 93 | + return ret |
318 | 94 |
|
319 | | - def compute_spec_key(self, arg, align): |
| 95 | + @staticmethod |
| 96 | + def get_arg_specialization(arg, ty, **kwargs): |
320 | 97 | """ |
321 | | - Return the ascii key for a given argument with a given set of properties |
| 98 | + Return a string unique to each possible specialization of the argument |
322 | 99 | """ |
323 | | - return AttrsDescriptor.get_property_key(arg, align) |
| 100 | + if ty == "int" and arg % 16 == 0 and kwargs.get("align", False): |
| 101 | + return "D" |
| 102 | + if ty == "tensor" and arg.data_ptr() % 16 == 0 and kwargs.get("align", False): |
| 103 | + return "D" |
| 104 | + return "" |
0 commit comments