Skip to content

Commit 65167dc

Browse files
authored
[KERNELS] added option and test to set idle sms in matmul_ogs (#7210)
1 parent 39b8ead commit 65167dc

File tree

2 files changed

+14
-2
lines changed

2 files changed

+14
-2
lines changed

python/triton_kernels/tests/test_matmul.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import triton_kernels.matmul_ogs_details.opt_flags as opt_flags
1010
from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig, MicroscalingCtx, FusedActivation, FnSpecs
1111
from triton_kernels.matmul_ogs import can_use_persistent_tma
12-
from triton_kernels.matmul_ogs import matmul_ogs, matmul_ogs_torch
12+
from triton_kernels.matmul_ogs import matmul_ogs_set_idle_sms, matmul_ogs, matmul_ogs_torch
1313
from triton_kernels.swiglu import swiglu, swiglu_fn, PrecisionConfig as SwiGLUPrecisionConfig
1414
# numerics utilities
1515
from triton_kernels.numerics import InFlexData, OutFlexData
@@ -444,6 +444,16 @@ def round_x(x, idx):
444444
tri_y_scale).abs() < 1e-10, f"ref_y_scale: {ref_y_scale}, tri_y_scale: {tri_y_scale.item()}"
445445

446446

447+
def test_set_idle_sms():
    """Check that ``matmul_ogs_set_idle_sms`` is reflected in the opt flags.

    The helper is called with a fixed number of idle SMs, after which
    ``make_opt_flags`` must report a persistent configuration whose
    ``idle_sms`` equals the requested value.

    The constraint installed by ``matmul_ogs_set_idle_sms`` is picked up by
    ``make_opt_flags`` without being passed in, i.e. it is shared state. It is
    therefore cleared again in a ``finally`` clause so that it cannot leak
    into tests that run later in the same process.
    """
    # NOTE(review): ``reset_opt_flags_constraints`` is assumed to drop every
    # previously registered constraint -- confirm against opt_flags.
    from triton_kernels.matmul_ogs_details.opt_flags import (
        make_opt_flags,
        reset_opt_flags_constraints,
    )

    num_idle_sms = 24
    matmul_ogs_set_idle_sms(num_idle_sms)
    try:
        flags = make_opt_flags(torch.float32, torch.float32, torch.float32, PrecisionConfig(), 0, 0, 0, None, False,
                               False, 1)
        assert flags.is_persistent
        assert flags.idle_sms == num_idle_sms
    finally:
        # Undo the constraint even when an assertion above fails; otherwise
        # "is_persistent"/"idle_sms" would stay forced for subsequent tests.
        reset_opt_flags_constraints()
455+
456+
447457
@pytest.mark.parametrize("m, n, k, mode", [
448458
(1200, 704, 608, "ragged"),
449459
(800, 800, 400, "batched"),

python/triton_kernels/triton_kernels/matmul_ogs.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from .matmul_ogs_details._matmul_ogs import _matmul_ogs
1414
from .matmul_ogs_details._p_matmul_ogs import _p_matmul_ogs, get_per_device_per_stream_alloc_fn
1515
from .matmul_ogs_details._finalize_matmul import _finalize_matmul
16-
from .matmul_ogs_details.opt_flags import make_opt_flags, OptFlags
16+
from .matmul_ogs_details.opt_flags import make_opt_flags, OptFlags, update_opt_flags_constraints
1717
from .matmul_ogs_details.fast_contiguous import fast_contiguous
1818
from .numerics_details.mxfp import SwizzlingType
1919
from .specialize import specialize
@@ -613,6 +613,8 @@ def _create_tma_descriptors(
613613

614614
return x_tensor_or_desc, w_desc_and_transpose, mx_desc_and_transpose
615615

616+
def matmul_ogs_set_idle_sms(num_idle_sms):
    """
    Register opt-flag constraints that request `num_idle_sms` idle SMs.

    Forwards two constraints to `update_opt_flags_constraints`:
    "is_persistent" is set to True and "idle_sms" is set to `num_idle_sms`.
    Returns None.

    NOTE(review): presumably `num_idle_sms` is a non-negative integer count of
    SMs the matmul should leave unused; this function does not validate it.
    """
    constraints = dict(is_persistent=True, idle_sms=num_idle_sms)
    update_opt_flags_constraints(constraints)
616618

617619
def matmul_ogs(x, w, bias,
618620
routing_data: RoutingData | None = None,

0 commit comments

Comments
 (0)