 import triton_kernels.matmul_ogs_details.opt_flags as opt_flags
 from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig, MicroscalingCtx, FusedActivation, FnSpecs
 from triton_kernels.matmul_ogs import can_use_persistent_tma
-from triton_kernels.matmul_ogs import matmul_ogs, matmul_ogs_torch
+from triton_kernels.matmul_ogs import matmul_ogs_set_idle_sms, matmul_ogs, matmul_ogs_torch
 from triton_kernels.swiglu import swiglu, swiglu_fn, PrecisionConfig as SwiGLUPrecisionConfig
 # numerics utilities
 from triton_kernels.numerics import InFlexData, OutFlexData
@@ -444,6 +444,16 @@ def round_x(x, idx):
             tri_y_scale).abs() < 1e-10, f"ref_y_scale: {ref_y_scale}, tri_y_scale: {tri_y_scale.item()}"
 
 
+def test_set_idle_sms():
+    from triton_kernels.matmul_ogs_details.opt_flags import make_opt_flags
+    num_idle_sms = 24
+    matmul_ogs_set_idle_sms(num_idle_sms)
+    flags = make_opt_flags(torch.float32, torch.float32, torch.float32, PrecisionConfig(), 0, 0, 0, None, False, False,
+                           1)
+    assert flags.is_persistent
+    assert flags.idle_sms == num_idle_sms
+
+
 @pytest.mark.parametrize("m, n, k, mode", [
     (1200, 704, 608, "ragged"),
     (800, 800, 400, "batched"),
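For reference, a minimal sketch of the usage pattern the new test exercises: matmul_ogs_set_idle_sms records a global idle-SM count, and the opt-flags machinery (and hence the persistent kernel launch) picks it up on subsequent calls, as the assertions on flags.is_persistent and flags.idle_sms above imply. The matmul_ogs(x, w, bias) call shape, the fp16 tensor shapes, and resetting the count back to 0 are illustrative assumptions, not part of this change.

# Illustrative sketch (not part of the diff): reserve SMs globally, then run a
# matmul. Assumes a CUDA device and that matmul_ogs(x, w, bias) accepts plain
# fp16 tensors as elsewhere in this test suite.
import torch
from triton_kernels.matmul_ogs import matmul_ogs, matmul_ogs_set_idle_sms

matmul_ogs_set_idle_sms(24)  # leave 24 SMs free for concurrent work

x = torch.randn(1024, 512, device="cuda", dtype=torch.float16)
w = torch.randn(512, 256, device="cuda", dtype=torch.float16)
y = matmul_ogs(x, w, None)   # bias=None; persistent kernel launches on fewer SMs

matmul_ogs_set_idle_sms(0)   # assumed way to restore the default (no idle SMs)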