Skip to content

Commit 09649e2

Browse files
authored
Include constexprs in cache keys, fixes #7322, reverts #7344 (#7348)
See #7332, this reinstates the change where constexprs are included in kernel cache keys, but only includes constexprs, and doesn't include the `id(var_dict)` which is nonpredictable. (from: `self.used_global_vals[(node.id, id(var_dict))] = (copy.copy(val), var_dict)`) # New contributor declaration - [x] I am not making a trivial change, such as fixing a typo in a comment. - [x] I have written a PR description following these [rules](https://cbea.ms/git-commit/#why-not-how). - [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`. - Select one of the following. - [x] I have added tests. - `/test` for `lit` tests - `/unittest` for C++ tests - `/python/test` for end-to-end tests - [ ] This PR does not need a test because `FILL THIS IN`. - Select one of the following. - [x] I have not added any `lit` tests. - [ ] The `lit` tests I have added follow these [best practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices), including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.)
1 parent cb7197e commit 09649e2

File tree

2 files changed

+25
-0
lines changed

2 files changed

+25
-0
lines changed

python/test/unit/runtime/test_cache.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -400,6 +400,25 @@ def kernel():
400400
assert not kernel.used_global_vals
401401

402402

403+
def test_constexpr_cache_invalidation_recreated(device):
404+
405+
def test_run(val):
406+
VAL = tl.constexpr(val)
407+
408+
@triton.jit
409+
def kernel(out):
410+
tl.store(out, VAL)
411+
412+
out = torch.zeros(1, device=device)
413+
kernel[(1, )](out)
414+
return out.item()
415+
416+
assert test_run(123) == 123
417+
assert test_run(123) == 123
418+
assert test_run(1234) == 1234
419+
assert test_run(1234) == 1234
420+
421+
403422
def test_jit_warmup_cache(device) -> None:
404423

405424
@triton.jit

python/triton/runtime/jit.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -696,6 +696,12 @@ def cache_key(self):
696696
dependencies_finder.visit(self.parse())
697697
self.hash = dependencies_finder.ret + str(self.starting_line_number)
698698
self.used_global_vals = dict(sorted(dependencies_finder.used_global_vals.items()))
699+
700+
from triton.language.core import constexpr
701+
self.hash += str([(name, val)
702+
for (name, _), (val, _) in self.used_global_vals.items()
703+
if isinstance(val, constexpr)])
704+
self.hash = hashlib.sha256(self.hash.encode("utf-8")).hexdigest()
699705
return self.hash
700706

701707
@property

0 commit comments

Comments
 (0)