Skip to content

Commit b3ce5fb

Browse files
[TEST] Modify mxfp tests to allow faster runtime (#5602)
Signed-off-by: Whitney Tsang <[email protected]>
1 parent f7f1e8f commit b3ce5fb

File tree

2 files changed

+5
-5
lines changed

2 files changed

+5
-5
lines changed

python/test/unit/intel/test_mxfp_matmul.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,9 @@ def mxfp_matmul( #
3636
a_scale, b_scale, #
3737
M, N, K, #
3838
stride_scale, #
39-
stride_am, stride_ak, #
40-
stride_bk, stride_bn, #
41-
stride_cm, stride_cn, #
39+
stride_am: tl.constexpr, stride_ak: tl.constexpr, #
40+
stride_bk: tl.constexpr, stride_bn: tl.constexpr, #
41+
stride_cm: tl.constexpr, stride_cn: tl.constexpr, #
4242
DTYPE_A: tl.constexpr, #
4343
DTYPE_B: tl.constexpr, #
4444
BLOCK_M: tl.constexpr, #

python/test/unit/language/test_matmul.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1252,9 +1252,9 @@ def create_operand(dtype: str, size0: int, size1: int, k_dim: int, transpose: bo
12521252
kernel_kwargs = {}
12531253
if is_hip():
12541254
kernel_kwargs["matrix_instr_nonkdim"] = nonKDim
1255-
if is_xpu() and (128, 256, 256) == (BLOCK_M, BLOCK_N, BLOCK_K) and not CONST_SCALE and not PACK_B_ALONG_K:
1256-
kernel_kwargs["num_warps"] = 8
12571255
if is_xpu():
1256+
# Since the block sizes are big, we use num_warps = 32 to avoid pressure problems.
1257+
kernel_kwargs["num_warps"] = 32
12581258
kernel_kwargs["grf_mode"] = "256"
12591259
out = mxfp8_mxfp4_matmul[grid](a, b, output, a_scale, b_scale, M, N, K, stride_scale, a.stride(0), a.stride(1),
12601260
b.stride(0), b.stride(1), output.stride(0), output.stride(1), not CONST_SCALE,

0 commit comments

Comments
 (0)