diff --git a/python/test/unit/intel/test_mxfp_matmul.py b/python/test/unit/intel/test_mxfp_matmul.py
index e9de591541..a1b855fa00 100644
--- a/python/test/unit/intel/test_mxfp_matmul.py
+++ b/python/test/unit/intel/test_mxfp_matmul.py
@@ -36,9 +36,9 @@ def mxfp_matmul(  #
         a_scale, b_scale,  #
         M, N, K,  #
         stride_scale,  #
-        stride_am, stride_ak,  #
-        stride_bk, stride_bn,  #
-        stride_cm, stride_cn,  #
+        stride_am: tl.constexpr, stride_ak: tl.constexpr,  #
+        stride_bk: tl.constexpr, stride_bn: tl.constexpr,  #
+        stride_cm: tl.constexpr, stride_cn: tl.constexpr,  #
         DTYPE_A: tl.constexpr,  #
         DTYPE_B: tl.constexpr,  #
         BLOCK_M: tl.constexpr,  #
diff --git a/python/test/unit/language/test_matmul.py b/python/test/unit/language/test_matmul.py
index 9630aef886..49408862c7 100644
--- a/python/test/unit/language/test_matmul.py
+++ b/python/test/unit/language/test_matmul.py
@@ -1252,9 +1252,9 @@ def create_operand(dtype: str, size0: int, size1: int, k_dim: int, transpose: bo
     kernel_kwargs = {}
     if is_hip():
         kernel_kwargs["matrix_instr_nonkdim"] = nonKDim
-    if is_xpu() and (128, 256, 256) == (BLOCK_M, BLOCK_N, BLOCK_K) and not CONST_SCALE and not PACK_B_ALONG_K:
-        kernel_kwargs["num_warps"] = 8
     if is_xpu():
+        # Since the block sizes are big, use num_warps = 32 to avoid register pressure problems.
+        kernel_kwargs["num_warps"] = 32
         kernel_kwargs["grf_mode"] = "256"
     out = mxfp8_mxfp4_matmul[grid](a, b, output, a_scale, b_scale, M, N, K, stride_scale, a.stride(0), a.stride(1),
                                    b.stride(0), b.stride(1), output.stride(0), output.stride(1), not CONST_SCALE,