diff --git a/python/perf-kernels/tools/plot-layout/plot_layout.py b/python/perf-kernels/tools/plot-layout/plot_layout.py index f506b8e7d88f..63eb4b81ad3b 100644 --- a/python/perf-kernels/tools/plot-layout/plot_layout.py +++ b/python/perf-kernels/tools/plot-layout/plot_layout.py @@ -559,11 +559,18 @@ def main(): print(f"Plotting dot operation with shapes=M{M}-N{N}-K{K},{kWidth=},{kGroup=},{warpsPerCTA=},{CTAShape=}") assert M != 0 and CTAShape[0] <= M and M % CTAShape[0] == 0, "bad tensor dimension M" assert N != 0 and CTAShape[1] <= N and N % CTAShape[1] == 0, "bad tensor dimension N" + assert K != 0, "bad tensor dimension K" if isMixedPrecBtwF8AndF4OrF6(dtype_a, dtype_b): ## In the case of mixed precision between 8-bit and 4 or 6-bit, ## ignore kWidth and kGroup since inA and inB have different kWidth and kGroup values - kDim = 128 - assert K != 0 and K % kDim == 0, f"one mfma instruction requires {kDim:.0f} elements along k dim but BLOCK_K = {K}" + if mfmaNonKDim == 16: + kDim = 128 + elif mfmaNonKDim == 32: + kDim = 64 + else: + raise NotImplementedError("scaled dot only supports 32x32x64 or 16x16x128 for now") + assert K % kDim == 0, \ + f"one mfma instruction requires multiple of {kDim:.0d} elements along k dim but BLOCK_K = {K}" kpack = 1 CBSZ = matrixFormatTable[dtype_b] if trans else matrixFormatTable[dtype_a] BLGP = matrixFormatTable[dtype_a] if trans else matrixFormatTable[dtype_b] @@ -573,7 +580,7 @@ def main(): plot_scale = scale else: kDim = kWidth * kGroup * 64 / mfmaNonKDim - assert K != 0 and K % kDim == 0, f"one mfma instruction requires {kDim:.0f} elements along k dim but BLOCK_K = {K}" + assert K == kDim, f"one mfma instruction requires multiple of {kDim:.0d} elements along k dim but BLOCK_K = {K}" mfma_inst_str, kpack, CBSZ, BLGP, plot_scale = checkMfmaValidity(mfmaNonKDim, kWidth, kGroup, dtype_a, dtype_b, trans, scale) isMixed864 = False