
Commit 9311c40

[TEST] Do not pass tl.constexpr as an argument to a kernel and fix leaking hook (#7682)
It is breaking CI
Parent: eb66546

File tree

2 files changed: +35 / -31 lines


python/triton_kernels/tests/test_specialize.py

Lines changed: 33 additions & 29 deletions
@@ -53,32 +53,36 @@ def cache_hook(*args, **kwargs):
         fn_name = kwargs["fn"].name
         module_name = kwargs["fn"].module

-    triton.knobs.runtime.jit_cache_hook = cache_hook
-    o = torch.empty((1, ), dtype=torch.float32, device=device)
-    k = specialized_kernel[(1, )](o, )
-    hash = k.hash
-    assert o.item() == 1.0
-    assert module_name == "tests.test_specialize"
-    assert fn_name == "cacheable_kernel"
-
-    compile_count = 0
-
-    def count_hook(*args, **kwargs):
-        nonlocal compile_count
-        compile_count += 1
-
-    triton.knobs.runtime.jit_cache_hook = count_hook
-    # clear the cache
-    specialized_kernel.device_caches.clear()
-
-    # retrieve the kernel from name and preload it.
-    fn = retrieve_fn(module_name, fn_name)
-    assert fn == specialized_kernel
-    preload = fn.preload(specialization_data)
-    assert compile_count == 1
-    assert preload.hash == hash
-
-    # verify that we hit the cache.
-    compile_count = 0
-    specialized_kernel[(1, )](o, )
-    assert compile_count == 0
+    prev_hook = triton.knobs.runtime.jit_cache_hook
+    try:
+        triton.knobs.runtime.jit_cache_hook = cache_hook
+        o = torch.empty((1, ), dtype=torch.float32, device=device)
+        k = specialized_kernel[(1, )](o, )
+        hash = k.hash
+        assert o.item() == 1.0
+        assert module_name == "tests.test_specialize"
+        assert fn_name == "cacheable_kernel"
+
+        compile_count = 0
+
+        def count_hook(*args, **kwargs):
+            nonlocal compile_count
+            compile_count += 1
+
+        triton.knobs.runtime.jit_cache_hook = count_hook
+        # clear the cache
+        specialized_kernel.device_caches.clear()
+
+        # retrieve the kernel from name and preload it.
+        fn = retrieve_fn(module_name, fn_name)
+        assert fn == specialized_kernel
+        preload = fn.preload(specialization_data)
+        assert compile_count == 1
+        assert preload.hash == hash
+
+        # verify that we hit the cache.
+        compile_count = 0
+        specialized_kernel[(1, )](o, )
+        assert compile_count == 0
+    finally:
+        triton.knobs.runtime.jit_cache_hook = prev_hook
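
The test change above addresses the "leaking hook" half of the commit title: the previous version installed `cache_hook` and `count_hook` on `triton.knobs.runtime.jit_cache_hook` and never restored the original value, so the hooks stayed active for every test that ran afterwards. A minimal sketch of the same save/restore idea, packaged as a context manager; the helper name `temporary_jit_cache_hook` is illustrative and not part of the repository:

```python
import contextlib

import triton


@contextlib.contextmanager
def temporary_jit_cache_hook(hook):
    """Install `hook` as the JIT cache hook, restoring the previous one on exit."""
    # Save whatever hook (possibly None) is currently installed.
    prev_hook = triton.knobs.runtime.jit_cache_hook
    triton.knobs.runtime.jit_cache_hook = hook
    try:
        yield
    finally:
        # Runs even if the body raises or an assert fails,
        # so the hook cannot leak into later tests.
        triton.knobs.runtime.jit_cache_hook = prev_hook
```

Inside a test this would read `with temporary_jit_cache_hook(count_hook): specialized_kernel[(1, )](o, )`, which is equivalent to the explicit try/finally the commit adds.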

python/triton_kernels/triton_kernels/numerics_details/mxfp.py

Lines changed: 2 additions & 2 deletions
@@ -52,7 +52,7 @@ def downcast_to_mxfp(src_tensor: torch.Tensor, out_quant_type: torch.dtype, axis
     kernel_scale = out_scale.view(-1, out_scale.shape[-1])

     BLOCK_OUT_DIM = 128
-    BLOCK_QUANT_DIM = MXFP_BLOCK_SIZE
+    BLOCK_QUANT_DIM = MXFP_BLOCK_SIZE.value
     grid_out = triton.cdiv(kernel_src_tensor.shape[0], BLOCK_OUT_DIM)
     grid_quant = triton.cdiv(kernel_src_tensor.shape[1], BLOCK_QUANT_DIM)

@@ -93,7 +93,7 @@ def upcast_from_mxfp(tensor: torch.Tensor, scale: torch.Tensor, dtype: torch.dty
     reshaped_tensor = tensor.view(-1, tensor.shape[-1])
     reshaped_scale = scale.view(-1, scale.shape[-1])
     BLOCK_OUT_DIM = 128
-    BLOCK_QUANT_DIM = MXFP_BLOCK_SIZE
+    BLOCK_QUANT_DIM = MXFP_BLOCK_SIZE.value
     blocks_out_dim = triton.cdiv(reshaped_out.shape[0], BLOCK_OUT_DIM)
     blocks_quant_dim = triton.cdiv(reshaped_out.shape[1], BLOCK_QUANT_DIM)
     _upcast_from_mxfp[(blocks_out_dim, blocks_quant_dim)](reshaped_out, *reshaped_out.stride(), reshaped_scale,
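
The mxfp.py change addresses the other half of the commit title: judging by the `.value` access and the title, `MXFP_BLOCK_SIZE` is a `tl.constexpr`, and the host-side grid math plus the kernel launch want the plain Python integer it wraps rather than the wrapper object. A minimal sketch of the unwrap, using an assumed placeholder value of 32:

```python
import triton
import triton.language as tl

# Assumed stand-in for the library constant; the real value may differ.
MXFP_BLOCK_SIZE = tl.constexpr(32)

# Host-side code uses the unwrapped int, mirroring the change above.
BLOCK_QUANT_DIM = MXFP_BLOCK_SIZE.value          # plain Python int
grid_quant = triton.cdiv(1024, BLOCK_QUANT_DIM)  # e.g. 1024 columns -> 32 blocks

assert isinstance(BLOCK_QUANT_DIM, int)
assert grid_quant == 32
```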
