Skip to content

Commit c4ed65a

Browse files
Revert "Refactor compiler specializations to consider backend (#4734)"
This reverts commit cd1cc2d.
1 parent c4cc78f commit c4ed65a

File tree

8 files changed

+92
-229
lines changed

8 files changed

+92
-229
lines changed

python/test/unit/runtime/test_bindings.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,6 @@ def walk_fn(op):
5959
torch.empty((32, 32), device=device), # out_ptr
6060
16, # BLOCK_SIZE
6161
]
62-
target = triton.runtime.driver.active.get_current_target()
63-
backend = triton.compiler.compiler.make_backend(target)
6462
src = triton.compiler.compiler.ASTSource(
6563
fn=kernel,
6664
signature={
@@ -71,10 +69,12 @@ def walk_fn(op):
7169
constants={kernel.arg_names[i]: arg
7270
for i, arg in enumerate(args)
7371
if not isinstance(arg, torch.Tensor)},
74-
attrs=backend.get_attrs_descriptor(args, kernel.params),
72+
attrs=kernel._get_config(*args, ),
7573
)
7674

7775
context = triton._C.libtriton.ir.context()
76+
target = triton.runtime.driver.active.get_current_target()
77+
backend = triton.compiler.compiler.make_backend(target)
7878
options = backend.parse_options(dict())
7979
codegen_fns = dict()
8080
module_map = backend.get_module_map()

python/test/unit/runtime/test_subproc.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33

44
import triton
55
import triton.language as tl
6-
from triton.backends.compiler import AttrsDescriptor
76
from triton.compiler import ASTSource
87

98
target = triton.runtime.driver.active.get_current_target()
@@ -26,7 +25,7 @@ def kernel_sub(a, b, o, N: tl.constexpr):
2625

2726

2827
def test_compile_in_subproc() -> None:
29-
config = AttrsDescriptor.from_hints({i: 16 for i in range(4)})
28+
config = triton.compiler.AttrsDescriptor(tuple(range(4)), ())
3029
multiprocessing.set_start_method('fork')
3130
proc = multiprocessing.Process(target=compile_fn, args=(config, ))
3231
proc.start()
@@ -48,7 +47,7 @@ def kernel_dot(Z):
4847

4948

5049
def test_compile_in_forked_subproc(fresh_triton_cache) -> None:
51-
config = AttrsDescriptor.from_hints({0: 16})
50+
config = triton.compiler.AttrsDescriptor(tuple(range(1)), ())
5251
assert multiprocessing.get_start_method() == 'fork'
5352
proc = multiprocessing.Process(target=compile_fn_dot, args=(config, ))
5453
proc.start()
@@ -87,7 +86,7 @@ def test_compile_in_forked_subproc_with_forced_gc(fresh_triton_cache) -> None:
8786
gc.disable()
8887

8988
# stage 1.p
90-
config = AttrsDescriptor.from_hints({0: 16})
89+
config = triton.compiler.AttrsDescriptor(tuple(range(1)), ())
9190
compile_empty_kernel_with_gc(config)
9291

9392
# stage 2.p

python/triton/backends/compiler.py

Lines changed: 1 addition & 194 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import os
22
import re
3-
import hashlib
43
import subprocess
54

65
from abc import ABCMeta, abstractmethod, abstractclassmethod
@@ -9,184 +8,6 @@
98
from types import ModuleType
109

1110

12-
class AttrsDescriptor:
13-
"""
14-
This class handles compile-time properties for specific function parameters.
15-
16-
Different backends can add more properties to the common ones. The class
17-
contains two fields:
18-
19-
`arg_properties`: a dictionary containing the different compile-time properties for different
20-
parameters. I.e., the dictionary is a map from property names to parameter indices
21-
{
22-
"prop0": (0, 2, 3)
23-
"prop1": (0, 4, 5)
24-
}
25-
Different backends might need different properties on those paraemters to enable
26-
specific optimizations. The common compile time properties contained in this class
27-
are :
28-
- "tt.divisibility", i.e., is the given parameter divisible by 16
29-
- "tt.equal_to_1", i.e., is the given parameter an integer constant 1
30-
31-
`property_values`: a dictionary containing the value of the different compile-time properties, like:
32-
{
33-
"prop0": val0
34-
"prop1": val1
35-
}
36-
37-
`constant_properties`: a set containing the properties that can be used to determine if a parameter is constant
38-
39-
"""
40-
__slots__ = ('divisibility_16', 'equal_to_1', 'arg_properties', 'property_values', 'constant_properties')
41-
42-
def __init__(self, params=None, values=None):
43-
"""
44-
Initialize the compile-time properties
45-
46-
We can initialize the AttrsDescriptor class by passing the list of params
47-
of the function and their `values`. The function will try to apply the properties
48-
to the values and save the parameters in the `arg_properties` list. If we don't pass
49-
either the `params` or the `values` we should initialize the class via an alternative method
50-
(see `from_dict` or `from_hints`)
51-
"""
52-
# Default initialization
53-
self.arg_properties = {}
54-
self.property_values = {}
55-
self.constant_properties = set()
56-
57-
self._add_common_properties(params, values)
58-
self._add_backend_properties(params, values)
59-
self._init_slots()
60-
61-
def _add_common_properties(self, params, values):
62-
""" Add common compile-time properties """
63-
self.property_values["tt.divisibility"] = 16
64-
self.property_values["tt.equal_to"] = 1
65-
self.constant_properties.add("tt.equal_to")
66-
67-
if (params is None) or (values is None):
68-
return
69-
70-
# Compile properties deduction
71-
assert (len(params) == len(values))
72-
73-
# Divisibility property
74-
self.arg_properties["tt.divisibility"] = [
75-
param.num for param, arg in zip(params, values) if AttrsDescriptor.is_divisible_by_16(arg)
76-
and not param.do_not_specialize and not param.do_not_specialize_on_alignment
77-
]
78-
79-
# Equal to 1 property
80-
self.arg_properties["tt.equal_to"] = [
81-
param.num
82-
for param, arg in zip(params, values)
83-
if AttrsDescriptor.is_equal_to_1(arg) and not param.do_not_specialize
84-
]
85-
86-
def _add_backend_properties(self, params=None, values=None):
87-
""" This method is for different subclasses to implement their own compile-time properties """
88-
pass
89-
90-
def _init_slots(self):
91-
""" Initialize the slots of this class """
92-
for name, val in self.arg_properties.items():
93-
setattr(self, name.removeprefix('tt.') + '_' + str(self.property_values[name]), val)
94-
95-
def get_fn_attrs(self) -> Dict:
96-
"""
97-
Get the function attributes as a dictionary.
98-
99-
The returned dictionary will look like :
100-
{
101-
"arg0" : [(prop_name00, val00), (prop_name01, val01), ...)]}
102-
"arg1" : [(prop_name10, val10), (prop_name11, val11), ...)]}
103-
}
104-
"""
105-
attrs = {}
106-
for prop_name, arg_set in self.arg_properties.items():
107-
prop_val = self.property_values[prop_name]
108-
for arg in arg_set:
109-
attrs[arg] = attrs.get(arg, []) + [(prop_name, prop_val)]
110-
return attrs
111-
112-
def get_constants(self) -> Dict:
113-
""" Return a mapping of constant parameters to their values """
114-
constants = {}
115-
for prop_name in self.constant_properties:
116-
for p in self.arg_properties.get(prop_name, []):
117-
constants[p] = self.property_values[prop_name]
118-
return constants
119-
120-
def filter_out_constants(self):
121-
""" Return the same object, without properties marked as constants"""
122-
import copy
123-
c = copy.deepcopy(self)
124-
for prop_name in c.constant_properties:
125-
c.arg_properties.pop(prop_name, None)
126-
c.property_values.pop(prop_name, None)
127-
c.constant_properties = {}
128-
return c
129-
130-
def hash(self):
131-
values = [sorted(self.arg_properties.values())]
132-
values += [sorted(self.property_values.values())]
133-
values += [sorted(self.constant_properties)]
134-
key = str(values)
135-
return hashlib.sha256(key.encode("utf-8")).hexdigest()
136-
137-
def to_dict(self):
138-
return self.arg_properties
139-
140-
@staticmethod
141-
def from_dict(data):
142-
attrsDescriptor = AttrsDescriptor()
143-
for prop_name, param_ids in data.items():
144-
attrsDescriptor.arg_properties[prop_name] = param_ids
145-
attrsDescriptor._init_slots()
146-
return attrsDescriptor
147-
148-
@staticmethod
149-
def from_hints(hints: list[tuple[int, int]]):
150-
"""
151-
Create the class from a set of hints that are passed in.
152-
153-
Instead of deducing the properties from a list of paramaters and values,
154-
the user can pass in a list of `hints=[(param_index, val)]` and if `val`
155-
matches one of the values of the properties (e.g., `prop_val[prop0]`),
156-
then we insert `param_index` into the correct list (e.g., in
157-
`arg_properties[prop0]`)
158-
"""
159-
attrsDescriptor = AttrsDescriptor()
160-
for prop_name, prop_val in attrsDescriptor.property_values.items():
161-
attrsDescriptor.arg_properties[prop_name] = [i for i, h in hints.items() if h == prop_val]
162-
attrsDescriptor._init_slots()
163-
return attrsDescriptor
164-
165-
@staticmethod
166-
def is_divisible_by_16(x):
167-
""" Return if the argument is a multiple of 16"""
168-
if hasattr(x, "data_ptr"):
169-
return x.data_ptr() % 16 == 0
170-
elif isinstance(x, int):
171-
return x % 16 == 0
172-
if x is None:
173-
return True
174-
return False
175-
176-
@staticmethod
177-
def is_equal_to_1(x):
178-
""" Return if the argument is a constant 1"""
179-
return True if isinstance(x, int) and not isinstance(x, bool) and x == 1 else False
180-
181-
@staticmethod
182-
def get_property_key(val, align):
183-
if align and AttrsDescriptor.is_divisible_by_16(val):
184-
return "D"
185-
if AttrsDescriptor.is_equal_to_1(val):
186-
return "1"
187-
return "N"
188-
189-
19011
@dataclass(frozen=True)
19112
class GPUTarget(object):
19213
# Target backend, e.g., cuda, hip
@@ -258,20 +79,6 @@ def load_dialects(self, context):
25879
@abstractmethod
25980
def get_module_map(self) -> Dict[str, ModuleType]:
26081
"""
261-
Return a map of interface modules to their device-specific implementations
82+
Return a map of interface modules to their device-specific implementations.
26283
"""
26384
raise NotImplementedError
264-
265-
def get_attrs_descriptor(self, params, args):
266-
"""
267-
Return an attribute descriptor: given a set of parameters and arguments
268-
the descriptor stores a set of compile time properties that can improve code
269-
generation. Different backends might benefit from different properties
270-
"""
271-
return AttrsDescriptor(params, args)
272-
273-
def compute_spec_key(self, arg, align):
274-
"""
275-
Return the ascii key for a given argument with a given set of properties
276-
"""
277-
return AttrsDescriptor.get_property_key(arg, align)

python/triton/compiler/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from .compiler import CompiledKernel, ASTSource, compile, make_backend, LazyDict
1+
from .compiler import CompiledKernel, ASTSource, compile, AttrsDescriptor, make_backend, LazyDict
22
from .errors import CompilationError
33

44
__all__ = ["compile", "make_backend", "ASTSource", "AttrsDescriptor", "CompiledKernel", "CompilationError", "LazyDict"]

python/triton/compiler/code_generator.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1270,7 +1270,7 @@ def kernel_suffix(signature, specialization):
12701270
suffix += str(i)
12711271
if i in specialization.equal_to_1:
12721272
suffix += 'c'
1273-
if i in specialization.divisibility_16:
1273+
if i in specialization.divisible_by_16:
12741274
suffix += 'd'
12751275
return suffix
12761276

@@ -1284,21 +1284,17 @@ def ast_to_ttir(fn, specialization, context, options, codegen_fns, module_map):
12841284
gscope = fn.__globals__.copy()
12851285
function_name = fn.repr(specialization)
12861286
tys = list(specialization.signature.values())
1287-
new_constants = attrs.get_constants()
1288-
for k in new_constants:
1289-
if k in tys and tys[k] == "i1" and new_constants[k] == 1:
1290-
new_constants[k] = True
1287+
new_constants = {k: True if k in tys and tys[k] == "i1" else 1 for k in attrs.equal_to_1}
1288+
new_attrs = {k: [("tt.divisibility", 16)] for k in attrs.divisible_by_16}
12911289

1292-
new_attrs = attrs.filter_out_constants()
1293-
fn_attrs = new_attrs.get_fn_attrs()
12941290
all_constants = constants.copy()
12951291
all_constants.update(new_constants)
12961292
arg_types = [str_to_ty(v) for k, v in specialization.signature.items() if k not in specialization.constants]
12971293
file_name, begin_line = get_jit_fn_file_line(fn)
12981294

12991295
prototype = language.function_type([], arg_types)
13001296
generator = CodeGenerator(context, prototype, gscope=gscope, constants=all_constants, function_name=function_name,
1301-
jit_fn=fn, attributes=fn_attrs, is_kernel=True, file_name=file_name,
1297+
jit_fn=fn, attributes=new_attrs, is_kernel=True, file_name=file_name,
13021298
begin_line=begin_line, options=options, codegen_fns=codegen_fns, module_map=module_map)
13031299
generator.visit(fn.parse())
13041300

python/triton/compiler/compiler.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,45 @@
33
import json
44
from .._C.libtriton import get_cache_invalidating_env_vars, ir
55
from ..backends import backends
6-
from ..backends.compiler import GPUTarget, AttrsDescriptor
6+
from ..backends.compiler import GPUTarget
77
from .. import __version__
88
from ..runtime.autotuner import OutOfResources
99
from ..runtime.cache import get_cache_manager, get_dump_manager, get_override_manager
1010
from ..runtime.driver import driver
1111
from ..tools.disasm import get_sass
1212
# TODO: this shouldn't be here
13+
from dataclasses import dataclass
1314
from .code_generator import ast_to_ttir
1415
from pathlib import Path
1516
import re
1617
import functools
1718
import os
1819

20+
21+
@dataclass
22+
class AttrsDescriptor:
23+
divisible_by_16: set = None
24+
equal_to_1: set = None
25+
26+
def __post_init__(self):
27+
if self.divisible_by_16 is None:
28+
self.divisible_by_16 = set()
29+
if self.equal_to_1 is None:
30+
self.equal_to_1 = set()
31+
32+
def to_dict(self):
33+
return {'divisible_by_16': list(self.divisible_by_16), 'equal_to_1': list(self.equal_to_1)}
34+
35+
@staticmethod
36+
def from_dict(data):
37+
return AttrsDescriptor(divisible_by_16=set(data.get('divisible_by_16', [])),
38+
equal_to_1=set(data.get('equal_to_1', [])))
39+
40+
def hash(self):
41+
key = str([sorted(x) for x in self.__dict__.values()])
42+
return hashlib.sha256(key.encode("utf-8")).hexdigest()
43+
44+
1945
# - ^\s*tt\.func\s+ : match the start of the string, any leading whitespace, the keyword func,
2046
# and any following whitespace
2147
# - (public\s+)? : optionally match the keyword public and any following whitespace

0 commit comments

Comments
 (0)