Skip to content

Commit 6293c4b

Browse files
Support new AttrsDescriptor (#2487)
2 parents a9ca5f0 + c6e60f0 commit 6293c4b

File tree

11 files changed

+235
-94
lines changed

11 files changed

+235
-94
lines changed

.github/actions/setup-pytorch/action.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ runs:
8282
uses: ./.github/actions/load
8383
env:
8484
# Increase this value to reset cache
85-
CACHE_NUMBER: 11
85+
CACHE_NUMBER: 12
8686
with:
8787
path: pytorch
8888
key: pytorch-$PYTORCH_CACHE_KEY-$CACHE_NUMBER

.github/pins/pytorch-upstream.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
8321eec009c8c79145ebccd51fdfc336e5f8b848
1+
487873f7cafeb0fd390eaefe40496b804bceabbd

python/test/unit/runtime/test_bindings.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ def walk_fn(op):
5959
torch.empty((32, 32), device=device), # out_ptr
6060
16, # BLOCK_SIZE
6161
]
62+
target = triton.runtime.driver.active.get_current_target()
63+
backend = triton.compiler.compiler.make_backend(target)
6264
src = triton.compiler.compiler.ASTSource(
6365
fn=kernel,
6466
signature={
@@ -69,12 +71,10 @@ def walk_fn(op):
6971
constants={kernel.arg_names[i]: arg
7072
for i, arg in enumerate(args)
7173
if not isinstance(arg, torch.Tensor)},
72-
attrs=kernel._get_config(*args, ),
74+
attrs=backend.get_attrs_descriptor(args, kernel.params),
7375
)
7476

7577
context = triton._C.libtriton.ir.context()
76-
target = triton.runtime.driver.active.get_current_target()
77-
backend = triton.compiler.compiler.make_backend(target)
7878
options = backend.parse_options(dict())
7979
codegen_fns = dict()
8080
module_map = backend.get_module_map()

python/test/unit/runtime/test_subproc.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import triton
55
import triton.language as tl
6+
from triton.backends.compiler import AttrsDescriptor
67
from triton.compiler import ASTSource
78

89
target = triton.runtime.driver.active.get_current_target()
@@ -25,7 +26,7 @@ def kernel_sub(a, b, o, N: tl.constexpr):
2526

2627

2728
def test_compile_in_subproc() -> None:
28-
config = triton.compiler.AttrsDescriptor(tuple(range(4)), ())
29+
config = AttrsDescriptor.from_hints({i: 16 for i in range(4)})
2930
multiprocessing.set_start_method('fork')
3031
proc = multiprocessing.Process(target=compile_fn, args=(config, ))
3132
proc.start()
@@ -47,7 +48,7 @@ def kernel_dot(Z):
4748

4849

4950
def test_compile_in_forked_subproc(fresh_triton_cache) -> None:
50-
config = triton.compiler.AttrsDescriptor(tuple(range(1)), ())
51+
config = AttrsDescriptor.from_hints({0: 16})
5152
assert multiprocessing.get_start_method() == 'fork'
5253
proc = multiprocessing.Process(target=compile_fn_dot, args=(config, ))
5354
proc.start()
@@ -86,7 +87,7 @@ def test_compile_in_forked_subproc_with_forced_gc(fresh_triton_cache) -> None:
8687
gc.disable()
8788

8889
# stage 1.p
89-
config = triton.compiler.AttrsDescriptor(tuple(range(1)), ())
90+
config = AttrsDescriptor.from_hints({0: 16})
9091
compile_empty_kernel_with_gc(config)
9192

9293
# stage 2.p

python/triton/backends/compiler.py

Lines changed: 197 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import os
22
import re
3+
import hashlib
34
import subprocess
45

56
from abc import ABCMeta, abstractmethod, abstractclassmethod
@@ -8,6 +9,187 @@
89
from types import ModuleType
910

1011

12+
class AttrsDescriptor:
13+
"""
14+
This class handles compile-time properties for specific function parameters.
15+
16+
Different backends can add more properties to the common ones. The class
17+
contains two fields:
18+
19+
`arg_properties`: a dictionary containing the different compile-time properties for different
20+
parameters. I.e., the dictionary is a map from property names to parameter indices
21+
{
22+
"prop0": (0, 2, 3)
23+
"prop1": (0, 4, 5)
24+
}
25+
Different backends might need different properties on those paraemters to enable
26+
specific optimizations. The common compile time properties contained in this class
27+
are :
28+
- "tt.divisibility", i.e., is the given parameter divisible by 16
29+
- "tt.equal_to_1", i.e., is the given parameter an integer constant 1
30+
31+
`property_values`: a dictionary containing the value of the different compile-time properties, like:
32+
{
33+
"prop0": val0
34+
"prop1": val1
35+
}
36+
37+
`constant_properties`: a set containing the properties that can be used to determine if a parameter is constant
38+
39+
"""
40+
__slots__ = ('divisibility_16', 'equal_to_1', 'arg_properties', 'property_values', 'constant_properties')
41+
42+
def __init__(self, params=None, values=None):
43+
"""
44+
Initialize the compile-time properties
45+
46+
We can initialize the AttrsDescriptor class by passing the list of params
47+
of the function and their `values`. The function will try to apply the properties
48+
to the values and save the parameters in the `arg_properties` list. If we don't pass
49+
either the `params` or the `values` we should initialize the class via an alternative method
50+
(see `from_dict` or `from_hints`)
51+
"""
52+
# Default initialization
53+
self.arg_properties = {}
54+
self.property_values = {}
55+
self.constant_properties = set()
56+
57+
self._add_common_properties(params, values)
58+
self._add_backend_properties(params, values)
59+
self._init_slots()
60+
61+
def _add_common_properties(self, params, values):
62+
""" Add common compile-time properties """
63+
self.property_values["tt.divisibility"] = 16
64+
self.property_values["tt.equal_to"] = 1
65+
self.constant_properties.add("tt.equal_to")
66+
67+
if (params is None) or (values is None):
68+
return
69+
70+
# Compile properties deduction
71+
assert (len(params) == len(values))
72+
73+
# Divisibility property
74+
self.arg_properties["tt.divisibility"] = [
75+
param.num for param, arg in zip(params, values) if AttrsDescriptor.is_divisible_by_16(arg)
76+
and not param.do_not_specialize and not param.do_not_specialize_on_alignment
77+
]
78+
79+
# Equal to 1 property
80+
self.arg_properties["tt.equal_to"] = [
81+
param.num
82+
for param, arg in zip(params, values)
83+
if AttrsDescriptor.is_equal_to_1(arg) and not param.do_not_specialize
84+
]
85+
86+
def _add_backend_properties(self, params=None, values=None):
87+
""" This method is for different subclasses to implement their own compile-time properties """
88+
pass
89+
90+
def _init_slots(self):
91+
""" Initialize the slots of this class """
92+
for name, val in self.arg_properties.items():
93+
setattr(self, name.removeprefix('tt.') + '_' + str(self.property_values[name]), val)
94+
95+
def get_fn_attrs(self) -> Dict:
96+
"""
97+
Get the function attributes as a dictionary.
98+
99+
The returned dictionary will look like :
100+
{
101+
"arg0" : [(prop_name00, val00), (prop_name01, val01), ...)]}
102+
"arg1" : [(prop_name10, val10), (prop_name11, val11), ...)]}
103+
}
104+
"""
105+
attrs = {}
106+
for prop_name, arg_set in self.arg_properties.items():
107+
prop_val = self.property_values[prop_name]
108+
for arg in arg_set:
109+
attrs[arg] = attrs.get(arg, []) + [(prop_name, prop_val)]
110+
return attrs
111+
112+
def get_constants(self) -> Dict:
113+
""" Return a mapping of constant parameters to their values """
114+
constants = {}
115+
for prop_name in self.constant_properties:
116+
for p in self.arg_properties.get(prop_name, []):
117+
constants[p] = self.property_values[prop_name]
118+
return constants
119+
120+
def filter_out_constants(self):
121+
""" Return the same object, without properties marked as constants"""
122+
import copy
123+
c = copy.deepcopy(self)
124+
for prop_name in c.constant_properties:
125+
c.arg_properties.pop(prop_name, None)
126+
c.property_values.pop(prop_name, None)
127+
c.constant_properties = {}
128+
return c
129+
130+
def hash(self):
131+
values = [sorted(self.arg_properties.values())]
132+
values += [sorted(self.property_values.values())]
133+
values += [sorted(self.constant_properties)]
134+
key = str(values)
135+
return hashlib.sha256(key.encode("utf-8")).hexdigest()
136+
137+
def to_dict(self):
138+
return self.arg_properties
139+
140+
@staticmethod
141+
def from_dict(data):
142+
attrsDescriptor = AttrsDescriptor()
143+
for prop_name, param_ids in data.items():
144+
attrsDescriptor.arg_properties[prop_name] = param_ids
145+
attrsDescriptor._init_slots()
146+
return attrsDescriptor
147+
148+
@staticmethod
149+
def from_hints(hints: list[tuple[int, int]]):
150+
"""
151+
Create the class from a set of hints that are passed in.
152+
153+
Instead of deducing the properties from a list of paramaters and values,
154+
the user can pass in a list of `hints=[(param_index, val)]` and if `val`
155+
matches one of the values of the properties (e.g., `prop_val[prop0]`),
156+
then we insert `param_index` into the correct list (e.g., in
157+
`arg_properties[prop0]`)
158+
"""
159+
attrsDescriptor = AttrsDescriptor()
160+
for prop_name, prop_val in attrsDescriptor.property_values.items():
161+
attrsDescriptor.arg_properties[prop_name] = [i for i, h in hints.items() if h == prop_val]
162+
attrsDescriptor._init_slots()
163+
return attrsDescriptor
164+
165+
@staticmethod
166+
def is_divisible_by_16(x):
167+
""" Return if the argument is a multiple of 16"""
168+
if hasattr(x, "data_ptr"):
169+
return x.data_ptr() % 16 == 0
170+
elif isinstance(x, int):
171+
return x % 16 == 0
172+
if x is None:
173+
return True
174+
return False
175+
176+
@staticmethod
177+
def is_equal_to_1(x):
178+
""" Return if the argument is a constant 1"""
179+
return True if isinstance(x, int) and not isinstance(x, bool) and x == 1 else False
180+
181+
@staticmethod
182+
def get_property_key(val, align):
183+
if align and AttrsDescriptor.is_divisible_by_16(val):
184+
return "D"
185+
if AttrsDescriptor.is_equal_to_1(val):
186+
return "1"
187+
return "N"
188+
189+
def __repr__(self):
190+
return f"AttrsDescriptor.from_dict({self.arg_properties})"
191+
192+
11193
@dataclass(frozen=True)
12194
class GPUTarget(object):
13195
# Target backend, e.g., cuda, hip
@@ -79,6 +261,20 @@ def load_dialects(self, context):
79261
@abstractmethod
80262
def get_module_map(self) -> Dict[str, ModuleType]:
81263
"""
82-
Return a map of interface modules to their device-specific implementations.
264+
Return a map of interface modules to their device-specific implementations
83265
"""
84266
raise NotImplementedError
267+
268+
def get_attrs_descriptor(self, params, args):
269+
"""
270+
Return an attribute descriptor: given a set of parameters and arguments
271+
the descriptor stores a set of compile time properties that can improve code
272+
generation. Different backends might benefit from different properties
273+
"""
274+
return AttrsDescriptor(params, args)
275+
276+
def compute_spec_key(self, arg, align):
277+
"""
278+
Return the ascii key for a given argument with a given set of properties
279+
"""
280+
return AttrsDescriptor.get_property_key(arg, align)

python/triton/compiler/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from .compiler import CompiledKernel, ASTSource, compile, AttrsDescriptor, make_backend, LazyDict
1+
from .compiler import CompiledKernel, ASTSource, compile, make_backend, LazyDict
22
from .errors import CompilationError
33

44
__all__ = ["compile", "make_backend", "ASTSource", "AttrsDescriptor", "CompiledKernel", "CompilationError", "LazyDict"]

python/triton/compiler/code_generator.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1265,7 +1265,7 @@ def kernel_suffix(signature, specialization):
12651265
suffix += str(i)
12661266
if i in specialization.equal_to_1:
12671267
suffix += 'c'
1268-
if i in specialization.divisible_by_16:
1268+
if i in specialization.divisibility_16:
12691269
suffix += 'd'
12701270
return suffix
12711271

@@ -1279,17 +1279,21 @@ def ast_to_ttir(fn, specialization, context, options, codegen_fns, module_map):
12791279
gscope = fn.__globals__.copy()
12801280
function_name = fn.repr(specialization)
12811281
tys = list(specialization.signature.values())
1282-
new_constants = {k: True if k in tys and tys[k] == "i1" else 1 for k in attrs.equal_to_1}
1283-
new_attrs = {k: [("tt.divisibility", 16)] for k in attrs.divisible_by_16}
1282+
new_constants = attrs.get_constants()
1283+
for k in new_constants:
1284+
if k in tys and tys[k] == "i1" and new_constants[k] == 1:
1285+
new_constants[k] = True
12841286

1287+
new_attrs = attrs.filter_out_constants()
1288+
fn_attrs = new_attrs.get_fn_attrs()
12851289
all_constants = constants.copy()
12861290
all_constants.update(new_constants)
12871291
arg_types = [str_to_ty(v) for k, v in specialization.signature.items() if k not in specialization.constants]
12881292
file_name, begin_line = get_jit_fn_file_line(fn)
12891293

12901294
prototype = language.function_type([], arg_types)
12911295
generator = CodeGenerator(context, prototype, gscope=gscope, constants=all_constants, function_name=function_name,
1292-
jit_fn=fn, attributes=new_attrs, is_kernel=True, file_name=file_name,
1296+
jit_fn=fn, attributes=fn_attrs, is_kernel=True, file_name=file_name,
12931297
begin_line=begin_line, options=options, codegen_fns=codegen_fns, module_map=module_map)
12941298
generator.visit(fn.parse())
12951299

python/triton/compiler/compiler.py

Lines changed: 1 addition & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -3,45 +3,19 @@
33
import json
44
from .._C.libtriton import get_cache_invalidating_env_vars, ir
55
from ..backends import backends
6-
from ..backends.compiler import GPUTarget
6+
from ..backends.compiler import GPUTarget, AttrsDescriptor
77
from .. import __version__
88
from ..runtime.autotuner import OutOfResources
99
from ..runtime.cache import get_cache_manager, get_dump_manager, get_override_manager
1010
from ..runtime.driver import driver
1111
from ..tools.disasm import get_sass
1212
# TODO: this shouldn't be here
13-
from dataclasses import dataclass
1413
from .code_generator import ast_to_ttir
1514
from pathlib import Path
1615
import re
1716
import functools
1817
import os
1918

20-
21-
@dataclass
22-
class AttrsDescriptor:
23-
divisible_by_16: set = None
24-
equal_to_1: set = None
25-
26-
def __post_init__(self):
27-
if self.divisible_by_16 is None:
28-
self.divisible_by_16 = set()
29-
if self.equal_to_1 is None:
30-
self.equal_to_1 = set()
31-
32-
def to_dict(self):
33-
return {'divisible_by_16': list(self.divisible_by_16), 'equal_to_1': list(self.equal_to_1)}
34-
35-
@staticmethod
36-
def from_dict(data):
37-
return AttrsDescriptor(divisible_by_16=set(data.get('divisible_by_16', [])),
38-
equal_to_1=set(data.get('equal_to_1', [])))
39-
40-
def hash(self):
41-
key = str([sorted(x) for x in self.__dict__.values()])
42-
return hashlib.sha256(key.encode("utf-8")).hexdigest()
43-
44-
4519
# - ^\s*tt\.func\s+ : match the start of the string, any leading whitespace, the keyword func,
4620
# and any following whitespace
4721
# - (public\s+)? : optionally match the keyword public and any following whitespace

0 commit comments

Comments
 (0)