Skip to content

Commit 993c8da

Browse files
authored
[KERNELS] added idle_sms constraint for persistent matmul_ogs (#7184)
The idle_sms constraint is currently ignored when the kernel is non-persistent. There must be a better way to expose this.
1 parent 93eb090 commit 993c8da

File tree

1 file changed

+6
-2
lines changed
  • python/triton_kernels/triton_kernels/matmul_ogs_details

1 file changed

+6
-2
lines changed

python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ class OptFlags:
2121
split_k: int
2222
fused_scatter: bool
2323
is_persistent: bool
24+
idle_sms: int
2425
epilogue_subtile: int | None
2526
arch: str
2627
target_kernel_kwargs: dict
@@ -116,6 +117,7 @@ def make_default_opt_flags_amd(
116117
split_k=split_k,
117118
fused_scatter=constraints.get('fused_scatter', False),
118119
is_persistent=is_persistent,
120+
idle_sms=0,
119121
epilogue_subtile=constraints.get('epilogue_subtile', None),
120122
arch=None,
121123
target_kernel_kwargs=target_kernel_kwargs,
@@ -140,7 +142,7 @@ def make_default_opt_flags_nvidia(
140142
epilogue_effective_itemsize,
141143
constraints,
142144
):
143-
constraints_supported = ["block_m", "block_k", "split_k", "fused_scatter", "is_persistent", "epilogue_subtile", "num_stages"]
145+
constraints_supported = ["block_m", "block_k", "split_k", "fused_scatter", "is_persistent", "epilogue_subtile", "num_stages", "idle_sms"]
144146
assert not any([c not in constraints_supported for c in constraints]), constraints.keys()
145147
# tokens per expert
146148
if routing_data is None:
@@ -236,6 +238,7 @@ def make_default_opt_flags_nvidia(
236238
epilogue_subtile=epilogue_subtile,
237239
arch=arch,
238240
target_kernel_kwargs=dict(),
241+
idle_sms=constraints.get("idle_sms", 0),
239242
)
240243
# check constraints
241244
assert all(getattr(ret, ck) == cv for ck, cv in constraints.items() if cv is not None), f"{ret} != {constraints}"
@@ -283,7 +286,8 @@ def make_opt_flags(
283286
return _opt_flags
284287
args = [out_dtype, lhs_dtype, rhs_dtype, precision_config, microscaling_ctx, m, n, k,
285288
routing_data, can_use_persistent_tma, can_use_fused_scatter,
286-
enforce_bitwise_invariance, epilogue_effective_itemsize, _opt_flags_constraints]
289+
enforce_bitwise_invariance, epilogue_effective_itemsize,
290+
_opt_flags_constraints]
287291
backend = triton.runtime.driver.active.get_current_target().backend
288292
if backend == "hip":
289293
return make_default_opt_flags_amd(*args)

0 commit comments

Comments
 (0)