Skip to content

Commit d7e43ad

Browse files
authored
[kernels][tests] Use fresh_knobs when touching triton.knobs (#7687)
Follow up to #7682
1 parent 663e04e commit d7e43ad

File tree

5 files changed

+85
-79
lines changed

5 files changed

+85
-79
lines changed

python/test/conftest.py

Lines changed: 2 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
1-
import os
21
import pytest
32
import tempfile
4-
from typing import Optional, Set
53

64

75
def pytest_configure(config):
@@ -26,49 +24,9 @@ def fresh_triton_cache():
2624
yield tmpdir
2725

2826

29-
def _fresh_knobs_impl(monkeypatch, skipped_attr: Optional[Set[str]] = None):
30-
from triton import knobs
31-
32-
if skipped_attr is None:
33-
skipped_attr = set()
34-
35-
knobs_map = {
36-
name: knobset
37-
for name, knobset in knobs.__dict__.items()
38-
if isinstance(knobset, knobs.base_knobs) and knobset != knobs.base_knobs and name not in skipped_attr
39-
}
40-
41-
# We store which variables we need to unset below in finally because
42-
# monkeypatch doesn't appear to reset variables that were never set
43-
# before the monkeypatch.delenv call below.
44-
env_to_unset = []
45-
prev_propagate_env = knobs.propagate_env
46-
47-
def fresh_function():
48-
nonlocal env_to_unset
49-
for name, knobset in knobs_map.items():
50-
setattr(knobs, name, knobset.copy().reset())
51-
for knob in knobset.knob_descriptors.values():
52-
if knob.key in os.environ:
53-
monkeypatch.delenv(knob.key, raising=False)
54-
else:
55-
env_to_unset.append(knob.key)
56-
knobs.propagate_env = True
57-
return knobs
58-
59-
def reset_function():
60-
for name, knobset in knobs_map.items():
61-
setattr(knobs, name, knobset)
62-
for k in env_to_unset:
63-
if k in os.environ:
64-
del os.environ[k]
65-
knobs.propagate_env = prev_propagate_env
66-
67-
return fresh_function, reset_function
68-
69-
7027
@pytest.fixture
7128
def fresh_knobs(monkeypatch):
29+
from triton._internal_testing import _fresh_knobs_impl
7230
fresh_function, reset_function = _fresh_knobs_impl(monkeypatch)
7331
try:
7432
yield fresh_function()
@@ -83,6 +41,7 @@ def fresh_knobs_except_libraries(monkeypatch):
8341
information from the environment as these may be
8442
needed to successfully compile kernels.
8543
"""
44+
from triton._internal_testing import _fresh_knobs_impl
8645
fresh_function, reset_function = _fresh_knobs_impl(monkeypatch, skipped_attr={"build", "nvidia", "amd"})
8746
try:
8847
yield fresh_function()

python/triton/_internal_testing.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,10 @@
55
import triton
66
import triton.language as tl
77
from triton import knobs
8+
from typing import Optional, Set, Union
89
import pytest
910

1011
from numpy.random import RandomState
11-
from typing import Optional, Union
1212
from triton.runtime.jit import TensorWrapper, reinterpret, type_canonicalisation_dict
1313

1414
int_dtypes = ['int8', 'int16', 'int32', 'int64']
@@ -202,3 +202,44 @@ def unwrap_tensor(t: Union[torch.Tensor, triton.runtime.jit.TensorWrapper]) -> t
202202
if isinstance(t, triton.runtime.jit.TensorWrapper):
203203
return t.base
204204
return t
205+
206+
207+
def _fresh_knobs_impl(monkeypatch, skipped_attr: Optional[Set[str]] = None):
208+
from triton import knobs
209+
210+
if skipped_attr is None:
211+
skipped_attr = set()
212+
213+
knobs_map = {
214+
name: knobset
215+
for name, knobset in knobs.__dict__.items()
216+
if isinstance(knobset, knobs.base_knobs) and knobset != knobs.base_knobs and name not in skipped_attr
217+
}
218+
219+
# We store which variables we need to unset below in finally because
220+
# monkeypatch doesn't appear to reset variables that were never set
221+
# before the monkeypatch.delenv call below.
222+
env_to_unset = []
223+
prev_propagate_env = knobs.propagate_env
224+
225+
def fresh_function():
226+
nonlocal env_to_unset
227+
for name, knobset in knobs_map.items():
228+
setattr(knobs, name, knobset.copy().reset())
229+
for knob in knobset.knob_descriptors.values():
230+
if knob.key in os.environ:
231+
monkeypatch.delenv(knob.key, raising=False)
232+
else:
233+
env_to_unset.append(knob.key)
234+
knobs.propagate_env = True
235+
return knobs
236+
237+
def reset_function():
238+
for name, knobset in knobs_map.items():
239+
setattr(knobs, name, knobset)
240+
for k in env_to_unset:
241+
if k in os.environ:
242+
del os.environ[k]
243+
knobs.propagate_env = prev_propagate_env
244+
245+
return fresh_function, reset_function

python/triton_kernels/tests/conftest.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,13 @@ def pytest_addoption(parser):
88
@pytest.fixture
99
def device(request):
1010
return request.config.getoption("--device")
11+
12+
13+
@pytest.fixture
14+
def fresh_knobs(monkeypatch):
15+
from triton._internal_testing import _fresh_knobs_impl
16+
fresh_function, reset_function = _fresh_knobs_impl(monkeypatch)
17+
try:
18+
yield fresh_function()
19+
finally:
20+
reset_function()

python/triton_kernels/tests/test_matmul.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,7 @@ class Case:
255255
@pytest.mark.parametrize("is_persistent", [False, True])
256256
def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas, is_persistent, n_expts_tot,
257257
n_expts_act, n_expt_shards, mode, act_dtype_str, weight_dtype_str, block_m, hbm_swizzling, epilogue_subtile,
258-
device, opt_flags_scope):
258+
device, opt_flags_scope, fresh_knobs):
259259
# TODO: remove when Triton FP8 supports proper RTNE
260260
if is_cuda():
261261
if "float8" in weight_dtype_str and torch.cuda.get_device_capability()[0] < 9:

python/triton_kernels/tests/test_specialize.py

Lines changed: 30 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def cacheable_kernel():
3838
return get_specialized_kernel()
3939

4040

41-
def test_cacheable(device):
41+
def test_cacheable(device, fresh_knobs):
4242
specialized_kernel = get_specialized_kernel()
4343

4444
specialization_data = None
@@ -53,36 +53,32 @@ def cache_hook(*args, **kwargs):
5353
fn_name = kwargs["fn"].name
5454
module_name = kwargs["fn"].module
5555

56-
prev_hook = triton.knobs.runtime.jit_cache_hook
57-
try:
58-
triton.knobs.runtime.jit_cache_hook = cache_hook
59-
o = torch.empty((1, ), dtype=torch.float32, device=device)
60-
k = specialized_kernel[(1, )](o, )
61-
hash = k.hash
62-
assert o.item() == 1.0
63-
assert module_name == "tests.test_specialize"
64-
assert fn_name == "cacheable_kernel"
65-
66-
compile_count = 0
67-
68-
def count_hook(*args, **kwargs):
69-
nonlocal compile_count
70-
compile_count += 1
71-
72-
triton.knobs.runtime.jit_cache_hook = count_hook
73-
# clear the cache
74-
specialized_kernel.device_caches.clear()
75-
76-
# retrieve the kernel from name and preload it.
77-
fn = retrieve_fn(module_name, fn_name)
78-
assert fn == specialized_kernel
79-
preload = fn.preload(specialization_data)
80-
assert compile_count == 1
81-
assert preload.hash == hash
82-
83-
# verify that we hit the cache.
84-
compile_count = 0
85-
specialized_kernel[(1, )](o, )
86-
assert compile_count == 0
87-
finally:
88-
triton.knobs.runtime.jit_cache_hook = prev_hook
56+
triton.knobs.runtime.jit_cache_hook = cache_hook
57+
o = torch.empty((1, ), dtype=torch.float32, device=device)
58+
k = specialized_kernel[(1, )](o, )
59+
hash = k.hash
60+
assert o.item() == 1.0
61+
assert module_name == "tests.test_specialize"
62+
assert fn_name == "cacheable_kernel"
63+
64+
compile_count = 0
65+
66+
def count_hook(*args, **kwargs):
67+
nonlocal compile_count
68+
compile_count += 1
69+
70+
triton.knobs.runtime.jit_cache_hook = count_hook
71+
# clear the cache
72+
specialized_kernel.device_caches.clear()
73+
74+
# retrieve the kernel from name and preload it.
75+
fn = retrieve_fn(module_name, fn_name)
76+
assert fn == specialized_kernel
77+
preload = fn.preload(specialization_data)
78+
assert compile_count == 1
79+
assert preload.hash == hash
80+
81+
# verify that we hit the cache.
82+
compile_count = 0
83+
specialized_kernel[(1, )](o, )
84+
assert compile_count == 0

0 commit comments

Comments
 (0)