Skip to content

Commit 6b2fa6c

Browse files
authored
Implement fresh_triton_cache fixture through compilation.always_compile (#4363)
The main idea is that for tests where we need the cache to be generated each time, we can set this behavior via a special environment variable (`TRITON_ALWAYS_COMPILE` or corresponding knobs), instead of creating a new temporary folder and deleting it each time, which is problematic on Windows. It seems like this solution is more lightweight, maybe it will be possible to upstream it. --------- Signed-off-by: Anatoly Myachev <[email protected]>
1 parent 1b520ae commit 6b2fa6c

File tree

6 files changed

+30
-29
lines changed

6 files changed

+30
-29
lines changed

python/test/regression/conftest.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
1-
import os
21
import pytest
3-
import tempfile
42

53

64
def pytest_addoption(parser):
@@ -14,9 +12,7 @@ def device(request):
1412

1513
@pytest.fixture
1614
def fresh_triton_cache():
17-
with tempfile.TemporaryDirectory() as tmpdir:
18-
try:
19-
os.environ["TRITON_CACHE_DIR"] = tmpdir
20-
yield tmpdir
21-
finally:
22-
os.environ.pop("TRITON_CACHE_DIR", None)
15+
from triton import knobs
16+
with knobs.compilation.scope():
17+
knobs.compilation.always_compile = True
18+
yield

python/test/unit/conftest.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
import sys
33
import pathlib
44
import pytest
5-
import tempfile
65
from typing import Optional, Set
6+
import contextlib
77

88

99
def pytest_configure(config):
@@ -69,11 +69,23 @@ def device(request):
6969

7070
@pytest.fixture
7171
def fresh_triton_cache():
72-
with tempfile.TemporaryDirectory() as tmpdir:
73-
from triton import knobs
74-
with knobs.cache.scope():
75-
knobs.cache.dir = tmpdir
76-
yield tmpdir
72+
from triton import knobs
73+
with knobs.compilation.scope():
74+
knobs.compilation.always_compile = True
75+
yield
76+
77+
78+
@pytest.fixture
79+
def fresh_triton_cache_scope():
80+
from triton import knobs
81+
82+
@contextlib.contextmanager
83+
def fresh_cache():
84+
with knobs.compilation.scope():
85+
knobs.compilation.always_compile = True
86+
yield
87+
88+
yield fresh_cache
7789

7890

7991
def _fresh_knobs_impl(monkeypatch, skipped_attr: Optional[Set[str]] = None):

python/test/unit/runtime/test_cache.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import importlib.util
22
import itertools
33
import os
4-
import shutil
54
import pathlib
65

76
import pytest
@@ -495,7 +494,6 @@ def cache_hook(*args, **kwargs):
495494
assert specialization_data is not None
496495

497496
# clear the cache
498-
shutil.rmtree(fresh_triton_cache)
499497
kernel_add.device_caches[device][0].clear()
500498

501499
# preload the kernel

python/test/unit/runtime/test_compilation_listener.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from triton.knobs import CompileTimes
66
from triton.compiler.compiler import ASTSource, IRSource
77

8-
from typing import Any, Union
8+
from typing import Any, Union, Callable
99

1010
import torch
1111

@@ -17,7 +17,7 @@ def cumsum_kernel(ptr):
1717
tl.store(block, tl.cumsum(x, 0))
1818

1919

20-
def test_compile_stats(device: str, fresh_knobs_except_libraries: Any, fresh_triton_cache: str) -> None:
20+
def test_compile_stats(device: str, fresh_knobs_except_libraries: Any, fresh_triton_cache_scope: Callable) -> None:
2121
captured: Union[tuple[Union[ASTSource, IRSource], dict[str, Any], dict[str, Any], CompileTimes, bool], None] = None
2222

2323
def compile_listener(src: Union[ASTSource, IRSource], metadata: dict[str, str], metadata_group: dict[str, Any],
@@ -29,7 +29,8 @@ def compile_listener(src: Union[ASTSource, IRSource], metadata: dict[str, str],
2929
fresh_knobs_except_libraries.compilation.listener = compile_listener
3030

3131
x = torch.randn(4, device=device)
32-
cumsum_kernel[(1, )](x)
32+
with fresh_triton_cache_scope():
33+
cumsum_kernel[(1, )](x)
3334

3435
assert captured is not None
3536

python/test/unit/runtime/test_subproc.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import multiprocessing
2-
import shutil
32

43
import triton
54
import triton.language as tl
@@ -87,7 +86,6 @@ def test_compile_in_forked_subproc_with_forced_gc(fresh_triton_cache) -> None:
8786
compile_empty_kernel_with_gc()
8887

8988
# stage 2.p
90-
shutil.rmtree(fresh_triton_cache)
9189
mp_ctx = multiprocessing.get_context(start_method)
9290
proc = mp_ctx.Process(target=compile_empty_kernel_with_gc)
9391

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
1-
import os
21
import pytest
3-
import tempfile
42

53

64
def pytest_addoption(parser):
@@ -14,9 +12,7 @@ def device(request):
1412

1513
@pytest.fixture
1614
def fresh_triton_cache():
17-
with tempfile.TemporaryDirectory() as tmpdir:
18-
try:
19-
os.environ["TRITON_CACHE_DIR"] = tmpdir
20-
yield tmpdir
21-
finally:
22-
os.environ.pop("TRITON_CACHE_DIR", None)
15+
from triton import knobs
16+
with knobs.compilation.scope():
17+
knobs.compilation.always_compile = True
18+
yield

0 commit comments

Comments
 (0)