Skip to content

Commit caf24cc

Browse files
authored
Merge OpenAI Triton commit 67dc627 (#3112)
2 parents 292c08f + 8697d29 commit caf24cc

File tree

17 files changed

+314
-512
lines changed

17 files changed

+314
-512
lines changed

python/test/unit/language/test_core.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4519,7 +4519,7 @@ def test_value_specialization(value: int, value_type: str, device) -> None:
45194519

45204520
def repr(specialization):
45214521
ty = specialization.signature["value1"]
4522-
cst = '_'.join([k for k, v in specialization.constants.items() if v == 1])
4522+
cst = '_'.join([k for k, v in specialization.constants.items() if isinstance(k, str) and v == 1])
45234523
return f"kernel_{ty}_{cst}"
45244524

45254525
@triton.jit(repr=repr)

python/test/unit/language/test_tuple.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ def _tuple_fn0(Ptr, cst2: tl.constexpr, tuple1):
8181
def _tuple_serialize(Ptr, N1, tuple1, cst1: tl.constexpr, val1, tuple2):
8282
tl.static_assert(N1 is None)
8383
tl.static_assert(tuple1[1][1] is None)
84+
tl.static_assert(tuple1[1][3] == 4)
8485
tl.store(Ptr + 0, tl.load(tuple1[0]))
8586
tl.store(Ptr + 1, tuple1[1][0])
8687
tl.store(Ptr + 2, tl.load(tuple1[1][2]))
@@ -95,6 +96,6 @@ def test_serialize(device="xpu"):
9596
y0 = torch.tensor([10], dtype=torch.int32, device=device)
9697
z = torch.empty((10, ), dtype=torch.int32, device=device)
9798
# we want to check that JIT specialization propagates to tuples:
98-
_tuple_serialize[(1, )](z, None, (x0, (1, None, x1)), 20, 1, (y0, ))
99+
_tuple_serialize[(1, )](z, None, (x0, (1, None, x1, tl.constexpr(4))), 20, 1, (y0, ))
99100
ref = torch.tensor([8, 1, 12, 21, 10, 15, -1, 8, 1, 12], device=device)
100101
assert torch.equal(z, ref)

python/test/unit/runtime/test_bindings.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,6 @@ def walk_fn(op):
6868
constexprs={kernel.arg_names[i]: arg
6969
for i, arg in enumerate(args)
7070
if not isinstance(arg, torch.Tensor)},
71-
attrs=backend.get_attrs_descriptor(kernel.params, args),
7271
)
7372

7473
context = triton._C.libtriton.ir.context()

python/test/unit/runtime/test_cache.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -563,7 +563,7 @@ def kernel_add(a):
563563

564564
def cache_hook(*args, **kwargs):
565565
nonlocal pointer_range_32
566-
pointer_range_32 = kwargs["compile"]["configs"][0].pointer_range_32
566+
pointer_range_32 = [k for k, v in kwargs["compile"]["configs"][0].items() if ['tt.pointer_range', 32] in v]
567567

568568
JITFunction.cache_hook = cache_hook
569569
# In warmup we assume that the pointer range is 32 bits

python/test/unit/runtime/test_subproc.py

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,13 @@
33

44
import triton
55
import triton.language as tl
6-
from triton.backends.compiler import AttrsDescriptor
76
from triton.compiler import ASTSource
87

98
target = triton.runtime.driver.active.get_current_target()
109
start_method = 'fork' if 'fork' in multiprocessing.get_all_start_methods() else 'spawn'
1110

1211

13-
def compile_fn(attrs):
12+
def compile_fn():
1413

1514
@triton.jit
1615
def kernel_sub(a, b, o, N: tl.constexpr):
@@ -21,21 +20,19 @@ def kernel_sub(a, b, o, N: tl.constexpr):
2120
fn=kernel_sub,
2221
constexprs={'N': 32},
2322
signature={'a': "*fp32", 'b': "*fp32", 'o': "*fp32", 'N': 'constexpr'},
24-
attrs=attrs,
2523
)
2624
triton.compile(src=src, target=target)
2725

2826

2927
def test_compile_in_subproc() -> None:
30-
config = AttrsDescriptor.from_hints({i: 16 for i in range(4)})
3128
mp_ctx = multiprocessing.get_context(start_method)
32-
proc = mp_ctx.Process(target=compile_fn, args=(config, ))
29+
proc = mp_ctx.Process(target=compile_fn)
3330
proc.start()
3431
proc.join()
3532
assert proc.exitcode == 0
3633

3734

38-
def compile_fn_dot(attrs):
35+
def compile_fn_dot():
3936

4037
@triton.jit
4138
def kernel_dot(Z):
@@ -44,28 +41,27 @@ def kernel_dot(Z):
4441
z = tl.dot(z, z)
4542
tl.store(Z + offs, z)
4643

47-
src = ASTSource(fn=kernel_dot, signature={'Z': "*fp32"}, attrs=attrs, constexprs={})
44+
src = ASTSource(fn=kernel_dot, signature={'Z': "*fp32"})
4845
triton.compile(src=src, target=target)
4946

5047

5148
def test_compile_in_forked_subproc(fresh_triton_cache) -> None:
52-
config = AttrsDescriptor.from_hints({0: 16})
5349
mp_ctx = multiprocessing.get_context(start_method)
54-
proc = mp_ctx.Process(target=compile_fn_dot, args=(config, ))
50+
proc = mp_ctx.Process(target=compile_fn_dot)
5551
proc.start()
5652
proc.join()
5753
assert proc.exitcode == 0
5854

5955

60-
def compile_empty_kernel_with_gc(attrs):
56+
def compile_empty_kernel_with_gc():
6157

6258
@triton.jit
6359
def empty_kernel():
6460
pass
6561

6662
import gc
6763
gc.collect()
68-
src = ASTSource(fn=empty_kernel, signature={}, attrs=attrs, constexprs={})
64+
src = ASTSource(fn=empty_kernel, signature={})
6965
triton.compile(src=src, target=target)
7066

7167

@@ -88,13 +84,12 @@ def test_compile_in_forked_subproc_with_forced_gc(fresh_triton_cache) -> None:
8884
gc.disable()
8985

9086
# stage 1.p
91-
config = AttrsDescriptor.from_hints({0: 16})
92-
compile_empty_kernel_with_gc(config)
87+
compile_empty_kernel_with_gc()
9388

9489
# stage 2.p
9590
shutil.rmtree(fresh_triton_cache, ignore_errors=True)
9691
mp_ctx = multiprocessing.get_context(start_method)
97-
proc = mp_ctx.Process(target=compile_empty_kernel_with_gc, args=(config, ))
92+
proc = mp_ctx.Process(target=compile_empty_kernel_with_gc)
9893

9994
# stage 3.c
10095
proc.start()

python/triton/backends/compiler.py

Lines changed: 16 additions & 235 deletions
Original file line numberDiff line numberDiff line change
@@ -1,235 +1,11 @@
11
import os
22
import re
3-
import hashlib
43
import subprocess
54
import sysconfig
65
from abc import ABCMeta, abstractmethod
76
from dataclasses import dataclass
8-
from typing import Dict, List, Tuple, Union
7+
from typing import Dict, Union
98
from types import ModuleType
10-
from .._utils import find_paths_if
11-
12-
# Table that associates strings to AttrsDescriptor (sub)classes.
13-
# In this way we can dynamically select the correct class
14-
# constructor
15-
_descriptor_table = {}
16-
17-
18-
def register_descriptor(cls):
19-
"""
20-
Register a descriptor into the descriptor table
21-
"""
22-
_descriptor_table[cls.__name__] = cls
23-
return cls
24-
25-
26-
@register_descriptor
27-
class AttrsDescriptor:
28-
"""
29-
This class handles compile-time properties for specific function parameters.
30-
31-
Different backends can add more properties to the common ones. The class
32-
contains two fields:
33-
34-
`arg_properties`: a dictionary containing the different compile-time properties for different
35-
parameters. I.e., the dictionary is a map from property names to parameter indices
36-
{
37-
"prop0": (0, 2, 3)
38-
"prop1": (0, 4, 5)
39-
}
40-
Different backends might need different properties on those paraemters to enable
41-
specific optimizations. The common compile time properties contained in this class
42-
are :
43-
- "tt.divisibility", i.e., is the given parameter divisible by 16
44-
- "tt.equal_to_1", i.e., is the given parameter an integer constant 1
45-
46-
`property_values`: a dictionary containing the value of the different compile-time properties, like:
47-
{
48-
"prop0": val0
49-
"prop1": val1
50-
}
51-
52-
`constant_properties`: a set containing the properties that can be used to determine if a parameter is constant
53-
54-
"""
55-
__slots__ = ('divisibility_16', 'equal_to_1', 'equal_to_none', 'arg_properties', 'property_values',
56-
'constant_properties')
57-
58-
def __init__(self, params=None, values=None):
59-
"""
60-
Initialize the compile-time properties
61-
62-
We can initialize the AttrsDescriptor class by passing the list of params
63-
of the function and their `values`. The function will try to apply the properties
64-
to the values and save the parameters in the `arg_properties` list. If we don't pass
65-
either the `params` or the `values` we should initialize the class via an alternative method
66-
(see `from_dict` or `from_hints`)
67-
"""
68-
# Default initialization
69-
self.arg_properties = {}
70-
self.property_values = {}
71-
self.equal_to_none = {}
72-
self.constant_properties = set()
73-
74-
self._add_common_properties(params, values)
75-
self._add_backend_properties(params, values)
76-
self._init_slots()
77-
78-
def _add_common_properties(self, params, values):
79-
""" Add common compile-time properties """
80-
self.property_values["tt.divisibility"] = 16
81-
self.property_values["tt.equal_to"] = 1
82-
self.constant_properties.add("tt.equal_to")
83-
84-
if (params is None) or (values is None):
85-
return
86-
87-
# Compile properties deduction
88-
assert (len(params) == len(values))
89-
90-
# Divisibility property
91-
divisibility_16 = []
92-
for param, arg in zip(params, values):
93-
if param.do_not_specialize or \
94-
param.do_not_specialize_on_alignment:
95-
continue
96-
paths = find_paths_if(arg, lambda path, val: AttrsDescriptor.is_divisible_by_16(val))
97-
divisibility_16 += [(param.num, ) + x for x in paths]
98-
self.arg_properties["tt.divisibility"] = divisibility_16
99-
100-
# Equal to 1 property
101-
equal_to_1 = []
102-
for param, arg in zip(params, values):
103-
if param.do_not_specialize:
104-
continue
105-
paths = find_paths_if(arg, lambda path, val: AttrsDescriptor.is_equal_to_1(val))
106-
equal_to_1 += [(param.num, ) + x for x in paths]
107-
self.arg_properties["tt.equal_to"] = equal_to_1
108-
109-
# Equal to None property
110-
equal_to_none = []
111-
for param, arg in zip(params, values):
112-
paths = find_paths_if(arg, lambda path, val: val is None)
113-
equal_to_none += [(param.num, ) + x for x in paths]
114-
self.equal_to_none = equal_to_none
115-
116-
def _add_backend_properties(self, params=None, values=None):
117-
""" This method is for different subclasses to implement their own compile-time properties """
118-
pass
119-
120-
def _init_slots(self):
121-
""" Initialize the slots of this class """
122-
for name, val in self.arg_properties.items():
123-
setattr(self, name.removeprefix('tt.') + '_' + str(self.property_values[name]), val)
124-
125-
def get_fn_attrs(self) -> Dict:
126-
"""
127-
Get the function attributes as a dictionary.
128-
129-
The returned dictionary will look like :
130-
{
131-
"arg0" : [(prop_name00, val00), (prop_name01, val01), ...)]}
132-
"arg1" : [(prop_name10, val10), (prop_name11, val11), ...)]}
133-
}
134-
"""
135-
attrs = {}
136-
for prop_name, arg_set in self.arg_properties.items():
137-
prop_val = self.property_values[prop_name]
138-
for arg in arg_set:
139-
attrs[arg] = attrs.get(arg, []) + [(prop_name, prop_val)]
140-
return attrs
141-
142-
def get_constants(self) -> Dict:
143-
""" Return a mapping of constant parameters to their values """
144-
constants = {}
145-
for prop_name in self.constant_properties:
146-
for p in self.arg_properties.get(prop_name, []):
147-
constants[p] = self.property_values[prop_name]
148-
for v in self.equal_to_none:
149-
constants[v] = None
150-
return constants
151-
152-
def filter_out_constants(self):
153-
""" Return the same object, without properties marked as constants"""
154-
import copy
155-
c = copy.deepcopy(self)
156-
for prop_name in c.constant_properties:
157-
c.arg_properties.pop(prop_name, None)
158-
c.property_values.pop(prop_name, None)
159-
c.constant_properties = {}
160-
return c
161-
162-
def hash(self):
163-
values = [sorted(self.arg_properties.values())]
164-
values += [sorted(self.property_values.values())]
165-
values += [sorted(self.constant_properties)]
166-
key = str(values)
167-
return hashlib.sha256(key.encode("utf-8")).hexdigest()
168-
169-
def to_dict(self):
170-
"""
171-
Store the fields of this class in a serializable dictionary
172-
"""
173-
# We need to only store the `arg_properties` field. To initialize the
174-
# other fields we relay on the class type. We store it as a string in
175-
# the dictionary so that we can use it to invoke the appropriate
176-
# (sub)class constructor in the `from_dict` method.
177-
return {"arg_properties": self.arg_properties, "cls": type(self).__name__}
178-
179-
@staticmethod
180-
def from_dict(data):
181-
"""
182-
Create the object from a serializable dictionary
183-
"""
184-
attrs_descriptor = _descriptor_table[data["cls"]]()
185-
for prop_name, param_ids in data["arg_properties"].items():
186-
attrs_descriptor.arg_properties[prop_name] = list(map(tuple, param_ids))
187-
attrs_descriptor._init_slots()
188-
return attrs_descriptor
189-
190-
@classmethod
191-
def from_hints(cls, hints: List[Tuple[int, int]]):
192-
"""
193-
Create the class from a set of hints that are passed in.
194-
195-
Instead of deducing the properties from a list of paramaters and values,
196-
the user can pass in a list of `hints=[(param_index, val)]` and if `val`
197-
matches one of the values of the properties (e.g., `prop_val[prop0]`),
198-
then we insert `param_index` into the correct list (e.g., in
199-
`arg_properties[prop0]`)
200-
"""
201-
attrs_descriptor = cls()
202-
for prop_name, prop_val in attrs_descriptor.property_values.items():
203-
attrs_descriptor.arg_properties[prop_name] = [i for i, h in hints.items() if h == prop_val]
204-
attrs_descriptor._init_slots()
205-
return attrs_descriptor
206-
207-
@staticmethod
208-
def is_divisible_by_16(x):
209-
""" Return if the argument is a multiple of 16"""
210-
if hasattr(x, "data_ptr"):
211-
return x.data_ptr() % 16 == 0
212-
elif isinstance(x, int):
213-
return x % 16 == 0
214-
if x is None:
215-
return True
216-
return False
217-
218-
@staticmethod
219-
def is_equal_to_1(x):
220-
""" Return if the argument is a constant 1"""
221-
return True if isinstance(x, int) and not isinstance(x, bool) and x == 1 else False
222-
223-
@staticmethod
224-
def get_property_key(val, align):
225-
if align and AttrsDescriptor.is_divisible_by_16(val):
226-
return "D"
227-
if AttrsDescriptor.is_equal_to_1(val):
228-
return "1"
229-
return "N"
230-
231-
def __repr__(self):
232-
return f"AttrsDescriptor.from_dict({self.to_dict()!r})"
2339

23410

23511
@dataclass(frozen=True)
@@ -308,16 +84,21 @@ def get_module_map(self) -> Dict[str, ModuleType]:
30884
"""
30985
raise NotImplementedError
31086

311-
def get_attrs_descriptor(self, params, args):
312-
"""
313-
Return an attribute descriptor: given a set of parameters and arguments
314-
the descriptor stores a set of compile time properties that can improve code
315-
generation. Different backends might benefit from different properties
316-
"""
317-
return AttrsDescriptor(params, args)
87+
@staticmethod
88+
def parse_attr(desc):
89+
assert isinstance(desc, str)
90+
ret = []
91+
if "D" in desc:
92+
ret += [["tt.divisibility", 16]]
93+
return ret
31894

319-
def compute_spec_key(self, arg, align):
95+
@staticmethod
96+
def get_arg_specialization(arg, ty, **kwargs):
32097
"""
321-
Return the ascii key for a given argument with a given set of properties
98+
Return a string unique to each possible specialization of the argument
32299
"""
323-
return AttrsDescriptor.get_property_key(arg, align)
100+
if ty == "int" and arg % 16 == 0 and kwargs.get("align", False):
101+
return "D"
102+
if ty == "tensor" and arg.data_ptr() % 16 == 0 and kwargs.get("align", False):
103+
return "D"
104+
return ""

0 commit comments

Comments
 (0)