Skip to content

Commit 8b7074c

Browse files
authored
[KERNELS] no longer enforce persistent when is used (#7214)
1 parent 0e9706c commit 8b7074c

File tree

2 files changed

+6
-4
lines changed

2 files changed

+6
-4
lines changed

python/triton_kernels/tests/test_matmul.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -448,9 +448,8 @@ def test_set_idle_sms():
448448
from triton_kernels.matmul_ogs_details.opt_flags import make_opt_flags
449449
num_idle_sms = 24
450450
matmul_ogs_set_idle_sms(num_idle_sms)
451-
flags = make_opt_flags(torch.float32, torch.float32, torch.float32, PrecisionConfig(), 0, 0, 0, None, False, False,
452-
1)
453-
assert flags.is_persistent
451+
flags = make_opt_flags(torch.float32, torch.float32, torch.float32, PrecisionConfig(), \
452+
1024, 1024, 1024, None, True, False, 1)
454453
assert flags.idle_sms == num_idle_sms
455454

456455

python/triton_kernels/triton_kernels/matmul_ogs.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -614,7 +614,10 @@ def _create_tma_descriptors(
614614
return x_tensor_or_desc, w_desc_and_transpose, mx_desc_and_transpose
615615

616616
def matmul_ogs_set_idle_sms(num_idle_sms):
617-
update_opt_flags_constraints({"is_persistent": True, "idle_sms": num_idle_sms})
617+
"""
618+
persistent kernels will leave `num_idle_sms` idle
619+
"""
620+
update_opt_flags_constraints({"idle_sms": num_idle_sms})
618621

619622
def matmul_ogs(x, w, bias,
620623
routing_data: RoutingData | None = None,

0 commit comments

Comments
 (0)