Skip to content

Commit 478030b

Browse files
authored
[TEST] Reset jit_cache_hook and jit_cache_hook in the fresh_triton_cache fixture (#7885)
These are not environment variables and thereby cannot be automatically reset by `knobs.reset`
1 parent 0e3bf60 commit 478030b

File tree

4 files changed

+14
-5
lines changed

4 files changed

+14
-5
lines changed

python/test/conftest.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@ def device(request):
1919
def fresh_triton_cache():
2020
with tempfile.TemporaryDirectory() as tmpdir:
2121
from triton import knobs
22-
with knobs.cache.scope():
22+
23+
with knobs.cache.scope(), knobs.runtime.scope():
2324
knobs.cache.dir = tmpdir
2425
yield tmpdir
2526

python/test/unit/runtime/test_cache.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -583,7 +583,6 @@ def inc_counter(*args, **kwargs):
583583

584584
triton.knobs.runtime.jit_cache_hook = inc_counter
585585
final_kernel = kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, tl.float32, grid=(1, ))
586-
triton.knobs.runtime.jit_cache_hook = None
587586
assert counter == 0
588587
assert len(kernel_add.device_caches[device][0]) == 1
589588
assert final_kernel.hash == hash
@@ -692,8 +691,6 @@ def func4(x, y):
692691
def kernel(Y, fn: tl.constexpr, fn_args):
693692
tl.store(Y, fn(*fn_args))
694693

695-
triton.knobs.runtime.jit_cache_hook = None
696-
triton.knobs.runtime.jit_post_compile_hook = None
697694
y = torch.zeros((5, ), dtype=torch.int32, device=device)
698695
kernel[(1, )](y[0], func1, tuple())
699696
kernel[(1, )](y[1], func2, tuple())

python/triton_kernels/tests/conftest.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import pytest
2+
import tempfile
23

34

45
def pytest_addoption(parser):
@@ -18,3 +19,13 @@ def fresh_knobs():
1819
yield fresh_function()
1920
finally:
2021
reset_function()
22+
23+
24+
@pytest.fixture
25+
def fresh_triton_cache():
26+
with tempfile.TemporaryDirectory() as tmpdir:
27+
from triton import knobs
28+
29+
with knobs.cache.scope(), knobs.runtime.scope():
30+
knobs.cache.dir = tmpdir
31+
yield tmpdir

python/triton_kernels/tests/test_specialize.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def cacheable_kernel():
3838
return get_specialized_kernel()
3939

4040

41-
def test_cacheable(device, fresh_knobs):
41+
def test_cacheable(device, fresh_triton_cache):
4242
specialized_kernel = get_specialized_kernel()
4343

4444
specialization_data = None

0 commit comments

Comments
 (0)