Skip to content

Commit 2e359d3

Browse files
authored
[KERNEL] Add decorator to make caching play well with specialized kernel (#7634)
Decorator idea and implementation from @apgoucher. This allow specialized kernels to work with preload.
1 parent 9d64b33 commit 2e359d3

File tree

2 files changed

+102
-0
lines changed

2 files changed

+102
-0
lines changed
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
import torch
2+
import importlib
3+
from triton_kernels.specialize import cacheable, specialize
4+
import triton
5+
import triton.language as tl
6+
7+
8+
@triton.jit
9+
def template_kernel(o):
10+
cst = 1.0
11+
tl.store(o, cst)
12+
13+
14+
def retrieve_fn(module, name):
15+
module = importlib.import_module(module)
16+
fn = getattr(module, name)
17+
return fn
18+
19+
20+
_specialized_kernel = None
21+
22+
23+
def get_specialized_kernel():
24+
global _specialized_kernel
25+
if _specialized_kernel is not None:
26+
return _specialized_kernel
27+
import types
28+
spec_constants = {}
29+
spec_tuples = {}
30+
module = types.ModuleType("specialized_kernel")
31+
module.specialized = specialize(template_kernel, module, spec_constants, spec_tuples)
32+
_specialized_kernel = module.specialized
33+
return _specialized_kernel
34+
35+
36+
@cacheable
37+
def cacheable_kernel():
38+
return get_specialized_kernel()
39+
40+
41+
def test_cacheable(device):
42+
specialized_kernel = get_specialized_kernel()
43+
44+
specialization_data = None
45+
fn_name = None
46+
module_name = None
47+
48+
def cache_hook(*args, **kwargs):
49+
nonlocal specialization_data
50+
nonlocal fn_name
51+
nonlocal module_name
52+
specialization_data = kwargs["compile"]["specialization_data"]
53+
fn_name = kwargs["fn"].name
54+
module_name = kwargs["fn"].module
55+
56+
triton.knobs.runtime.jit_cache_hook = cache_hook
57+
o = torch.empty((1, ), dtype=torch.float32, device=device)
58+
k = specialized_kernel[(1, )](o, )
59+
hash = k.hash
60+
assert o.item() == 1.0
61+
assert module_name == "tests.test_specialize"
62+
assert fn_name == "cacheable_kernel"
63+
64+
compile_count = 0
65+
66+
def count_hook(*args, **kwargs):
67+
nonlocal compile_count
68+
compile_count += 1
69+
70+
triton.knobs.runtime.jit_cache_hook = count_hook
71+
# clear the cache
72+
specialized_kernel.device_caches.clear()
73+
74+
# retrieve the kernel from name and preload it.
75+
fn = retrieve_fn(module_name, fn_name)
76+
assert fn == specialized_kernel
77+
preload = fn.preload(specialization_data)
78+
assert compile_count == 1
79+
assert preload.hash == hash
80+
81+
# verify that we hit the cache.
82+
compile_count = 0
83+
specialized_kernel[(1, )](o, )
84+
assert compile_count == 0

python/triton_kernels/triton_kernels/specialize.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,24 @@
55
import triton
66

77

8+
def cacheable(f):
9+
"""
10+
A decorator that allow you to write something of the form:
11+
12+
@cacheable
13+
def my_kernel(): return (expression dynamically defining a kernel)
14+
15+
such that it interacts gracefully with triton cache and preload.
16+
"""
17+
18+
g = f()
19+
g.fn.__name__ = f.__name__
20+
g.fn.__module__ = f.__module__
21+
g.fn.__qualname__ = f.__qualname__
22+
g._fn_name = f"{f.__module__}.{f.__qualname__}"
23+
return g
24+
25+
826
def define_kernel(src, module, attrs=None, **extra_globals):
927
"""
1028
Dynamically create a Triton function or kernel from a src string,

0 commit comments

Comments
 (0)