
Commit 125c165

Avoid unnecessary overhead in hot codepath (#5409)
Also fixes an issue where the kernel would 'remember' the first backend encountered and assume that backend for subsequent compilations.
1 parent acc25d9 commit 125c165
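
For orientation, here is a minimal standalone sketch of the caching pattern this commit moves to, based on the diff below (the helper names are stand-ins, not Triton functions): each device index maps to a list whose first slot is the compiled-kernel cache and whose remaining slots (target, backend, argument binder) are filled lazily on the first launch, so the per-call hot path no longer re-queries the driver or rebuilds a backend.

# Minimal sketch of the lazy per-device cache layout (standalone; illustrative names only).
from collections import defaultdict


def query_target(device):        # stand-in for driver.active.get_current_target()
    return f"target-{device}"


def build_backend(target):       # stand-in for make_backend(target)
    return f"backend-for-{target}"


class LazyLauncher:
    def __init__(self):
        # Slot 0 holds {cache_key: compiled_kernel}; slots 1..3 are appended
        # lazily on the first launch for a given device.
        self.device_caches = defaultdict(lambda: [{}])

    def run(self, device, key):
        entry = self.device_caches[device]
        if len(entry) == 1:                      # first launch on this device
            target = query_target(device)
            backend = build_backend(target)
            binder = lambda k: f"bound({k})"     # stand-in argument binder
            entry[1:] = [target, backend, binder]
        kernel_cache, target, backend, binder = entry
        if key not in kernel_cache:              # compile only on a cache miss
            kernel_cache[key] = f"compiled({binder(key)}, {target}, {backend})"
        return kernel_cache[key]


launcher = LazyLauncher()
launcher.run(0, "f32x32")   # first call: queries target/backend, compiles
launcher.run(0, "f32x32")   # hot path: pure dict lookups, no driver queries

Bundling the kernel cache, target, backend, and binder into one per-device entry means a single dict lookup retrieves everything the hot path needs, and each device gets its own backend rather than whichever one was encountered first.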

File tree

2 files changed (+36, -36 lines)

python/test/unit/runtime/test_cache.py

Lines changed: 20 additions & 20 deletions
@@ -199,7 +199,7 @@ def kernel(X, i: tl.int32):
     kernel[(1, )](x, 8)
     kernel[(1, )](x, 16)
     kernel[(1, )](x, 17)
-    assert len(kernel.cache[device]) == 3
+    assert len(kernel.device_caches[device][0]) == 3


 GLOBAL_DEFAULT_ARG = 1
@@ -223,7 +223,7 @@ def kernel(X, i: tl.constexpr = GLOBAL_DEFAULT_ARG):
     assert x == torch.ones_like(x)

     device = getattr(torch, device).current_device()
-    assert len(kernel.cache[device]) == 1
+    assert len(kernel.device_caches[device][0]) == 1


 GLOBAL_VAR: tl.constexpr = 1
@@ -416,13 +416,13 @@ def kernel_add(a, b, o, N: tl.constexpr):
         32,
     ]
     device = getattr(torch, device).current_device()
-    assert len(kernel_add.cache[device]) == 0
+    assert len(kernel_add.device_caches[device][0]) == 0
     kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, ))
-    assert len(kernel_add.cache[device]) == 1
+    assert len(kernel_add.device_caches[device][0]) == 1
     kernel_add.warmup(*args, grid=(1, ))
-    assert len(kernel_add.cache[device]) == 1
+    assert len(kernel_add.device_caches[device][0]) == 1
     kernel_add.warmup(*args, grid=(1, ))
-    assert len(kernel_add.cache[device]) == 1
+    assert len(kernel_add.device_caches[device][0]) == 1


 def test_jit_debug(device) -> None:
@@ -433,12 +433,12 @@ def kernel(tmp):

     device = getattr(torch, device).current_device()
     tmp = torch.tensor([1], dtype=torch.int32, device=device)
-    assert len(kernel.cache[device]) == 0
+    assert len(kernel.device_caches[device][0]) == 0
     kernel[(1, )](tmp, debug=False)
-    assert len(kernel.cache[device]) == 1
+    assert len(kernel.device_caches[device][0]) == 1
     kernel[(1, )](tmp, debug=True)
-    assert len(kernel.cache[device]) == 2
-    bins = list(kernel.cache[device].values())
+    assert len(kernel.device_caches[device][0]) == 2
+    bins = list(kernel.device_caches[device][0].values())
     assert bins[0].asm['ttir'] != bins[1].asm['ttir']


@@ -455,18 +455,18 @@ def kernel_add_device(a, b, o, N: tl.constexpr):
         add_fn(a, b, o, N)

     device = getattr(torch, device).current_device()
-    assert len(kernel_add_device.cache[device]) == 0
+    assert len(kernel_add_device.device_caches[device][0]) == 0
     kernel_add_device.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, ))
-    assert len(kernel_add_device.cache[device]) == 1
-    bins = list(kernel_add_device.cache[device].values())
+    assert len(kernel_add_device.device_caches[device][0]) == 1
+    bins = list(kernel_add_device.device_caches[device][0].values())
     inline_ttir = bins[0].asm['ttir']
     add_fn.noinline = True
     add_fn.hash = None
     kernel_add_device.hash = None
-    kernel_add_device.cache[device].clear()
+    kernel_add_device.device_caches[device][0].clear()
     kernel_add_device.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, ))
-    assert len(kernel_add_device.cache[device]) == 1
-    bins = list(kernel_add_device.cache[device].values())
+    assert len(kernel_add_device.device_caches[device][0]) == 1
+    bins = list(kernel_add_device.device_caches[device][0].values())
     noinline_ttir = bins[0].asm['ttir']
     assert inline_ttir != noinline_ttir

@@ -514,12 +514,12 @@ def cache_hook(*args, **kwargs):

     # clear the cache
     shutil.rmtree(fresh_triton_cache)
-    kernel_add.cache[device].clear()
+    kernel_add.device_caches[device][0].clear()

     # preload the kernel
     kernel_preload = kernel_add.preload(specialization_data)
     assert kernel_preload.hash == hash
-    assert len(kernel_add.cache[device]) == 1
+    assert len(kernel_add.device_caches[device][0]) == 1

     # we should hit the cache and not compile anything
     counter = 0
@@ -532,7 +532,7 @@ def inc_counter(*args, **kwargs):
     final_kernel = kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, tl.float32, grid=(1, ))
     JITFunction.cache_hook = None
     assert counter == 0
-    assert len(kernel_add.cache[device]) == 1
+    assert len(kernel_add.device_caches[device][0]) == 1
     assert final_kernel.hash == hash

     # test that we can't preload a mismatched kernel
@@ -572,7 +572,7 @@ def compiled_hook(*args, **kwargs):
     kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, tl.float32, grid=(1, ))
     assert specialization_data is not None and specialization_data_compiled == specialization_data
     assert is_warmup is True
-    assert key in kernel_add.cache[getattr(torch, device).current_device()]
+    assert key in kernel_add.device_caches[getattr(torch, device).current_device()][0]


 @pytest.mark.skipif(reason="within_2g is a HIP specific optimization", condition=not is_hip())
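
The tests above now reach the compiled-kernel cache through slot 0 of the per-device entry instead of the old kernel.cache[device] dict. A hedged, illustrative sketch of that access pattern (assumes a working Triton install with a CUDA device; copy_kernel and check_cache_layout are made-up names for this example, and the length-4 list layout is an internal detail of this commit that may change):

import torch
import triton
import triton.language as tl


@triton.jit
def copy_kernel(src, dst, N: tl.constexpr):
    # Copy N contiguous elements from src to dst.
    offs = tl.arange(0, N)
    tl.store(dst + offs, tl.load(src + offs))


def check_cache_layout():
    device = torch.cuda.current_device()
    src = torch.ones(16, device="cuda")
    dst = torch.empty_like(src)
    copy_kernel[(1, )](src, dst, N=16)
    # After the first launch the per-device entry is the length-4 list
    # [kernel_cache, target, backend, binder]; slot 0 replaces the old
    # kernel.cache[device] dict used by the previous version of these tests.
    kernel_cache = copy_kernel.device_caches[device][0]
    assert len(kernel_cache) == 1


if __name__ == "__main__":
    check_cache_layout()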

python/triton/runtime/jit.py

Lines changed: 16 additions & 16 deletions
@@ -547,47 +547,49 @@ def add_pre_run_hook(self, hook):
         assert callable(hook)
         self.pre_run_hooks.append(hook)

-    def create_binder(self, backend):
+    def create_binder(self):
         """
         Precompute as much as possible.
         """
         from ..compiler import CompiledKernel, compile, ASTSource, make_backend
+        target = driver.active.get_current_target()
+        backend = make_backend(target)
         self.CompiledKernel = CompiledKernel
         self.compile = compile
         self.ASTSource = ASTSource
-        self.make_backend = make_backend
-        self.binder = create_function_from_signature(self.signature, self.params, backend)
+        binder = create_function_from_signature(self.signature, self.params, backend)
         self.constexpr_indices = [i for (i, p) in enumerate(self.params) if p.is_constexpr]
         self.non_constexpr_indices = [i for (i, p) in enumerate(self.params) if not p.is_constexpr]
         self.specialised_indices = [
             i for (i, p) in enumerate(self.params) if (not p.do_not_specialize) and (not p.is_constexpr)
         ]
+        return [target, backend, binder]

     def run(self, *args, grid, warmup, **kwargs):
         kwargs["debug"] = kwargs.get("debug", self.debug) or os.environ.get("TRITON_DEBUG", "0") == "1"

         # parse options
-        from ..compiler import make_backend
         device = driver.active.get_current_device()
         stream = driver.active.get_current_stream(device)
-        target = driver.active.get_current_target()
-        backend = make_backend(target)

         # Execute pre run hooks with args and kwargs
         for hook in self.pre_run_hooks:
             hook(*args, **kwargs)

-        if self.binder is None:
-            self.create_binder(backend)
-
-        bound_args, sig_and_spec, constexpr_vals, non_constexpr_vals, excess_kwargs = self.binder(*args, **kwargs)
+        # This is a length-4 list [kernel_cache, target, backend, binder]:
+        device_cache = self.device_caches[device]
+        if len(device_cache) == 1:
+            device_cache[1:] = self.create_binder()
+        bound_args, sig_and_spec, constexpr_vals, non_constexpr_vals, excess_kwargs = device_cache[3](*args, **kwargs)

         # compute cache key
         key = ''.join(sig_and_spec) + str((constexpr_vals, excess_kwargs))
-        kernel = self.cache[device].get(key, None)
+        kernel = device_cache[0].get(key, None)

         if kernel is None:
             # Kernel is not cached; we have to compile.
+            target = device_cache[1]
+            backend = device_cache[2]
             options = backend.parse_options(kwargs)

             # deprecated arguments
@@ -619,7 +621,7 @@ def run(self, *args, grid, warmup, **kwargs):
             # compile the kernel
             src = self.ASTSource(self, signature, constexprs, attrs)
             kernel = self.compile(src, target=target, options=options.__dict__)
-            self.cache[device][key] = kernel
+            device_cache[0][key] = kernel
             self._call_hook(key, signature, device, constexprs, options, [attrs], warmup, before=False)

             # Check that used global values have not changed.
@@ -659,8 +661,6 @@ def __init__(self, fn, version=None, do_not_specialize=None, do_not_specialize_o
         self.repr = lambda _: fn.__name__ if repr is None else repr(_)
         self.launch_metadata = launch_metadata

-        self.binder = None
-
         self.params = []
         for i, param in enumerate(self.signature.parameters.values()):
             dns = i in do_not_specialize or param.name in do_not_specialize
@@ -671,7 +671,7 @@ def __init__(self, fn, version=None, do_not_specialize=None, do_not_specialize_o
         self.src = textwrap.dedent(inspect.getsource(fn))
         self.src = self.src[re.search(r"^def\s+\w+\s*\(", self.src, re.MULTILINE).start():]
         # cache of just-in-time compiled kernels
-        self.cache = defaultdict(dict)
+        self.device_caches = defaultdict(lambda: [{}])
         self.hash = None

         # Map of global variables used by the function and any functions it
@@ -742,7 +742,7 @@ def preload(self, specialization_data):
         }
         key = deserialized_obj['key']
         kernel = compile(src, None, options)
-        self.cache[device][key] = kernel
+        self.device_caches[device][0][key] = kernel
         return kernel

     # we do not parse `src` in the constructor because
