| 1 | +from dataclasses import dataclass |
| 2 | +from pathlib import Path |
| 3 | + |
| 4 | + |
| 5 | +@dataclass |
| 6 | +class DotConfig: |
| 7 | + mfmaNonKDim: int |
| 8 | + kWidth: int |
| 9 | + kGroup: int |
| 10 | + trans: int |
| 11 | + warpsPerCTA: tuple |
| 12 | + |
| 13 | + |
| 14 | +matrixFormatTable = {'fp8': 0, 'bf8': 1, 'fp6': 2, 'bf6': 3, 'f4': 4} |
| 15 | + |
| 16 | + |
| 17 | +def matrixFormat(dtypeA, dtypeB): |
| 18 | + """ |
| 19 | + return CBSZ and BLGP according to data types |
| 20 | + b000: E4M3(FP8) |
| 21 | + b001: E5M2(BF8) |
| 22 | + b010: E2M3(FP6) |
| 23 | + b011: E3M2(BF6) |
| 24 | + b100: E2M1(FP4) |
| 25 | + """ |
| 26 | + return matrixFormatTable[dtypeA], matrixFormatTable[dtypeB] |
| 27 | + |
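For reference, a couple of illustrative lookups (not part of the diff) showing how the table above maps data types to the CBSZ/BLGP codes listed in the docstring:

    # fp8 encodes as b000 (0) and f4 as b100 (4)
    assert matrixFormat('fp8', 'f4') == (0, 4)
    # bf6 encodes as b011 (3) and bf8 as b001 (1)
    assert matrixFormat('bf6', 'bf8') == (3, 1)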
| 28 | + |
| 29 | +def isType4Or6Bit(dtype): |
| 30 | + return dtype == 'fp6' or dtype == 'bf6' or dtype == 'f4' |
| 31 | + |
| 32 | + |
| 33 | +def isType8BitFloat(dtype): |
| 34 | + return dtype == 'fp8' or dtype == 'bf8' |
| 35 | + |
| 36 | + |
| 37 | +def isType16Bit(dtype): |
| 38 | + return dtype == 'bf16' or dtype == 'fp16' |
| 39 | + |
| 40 | + |
| 41 | +def isMixedPrecType(dtype): |
| 42 | + return isType8BitFloat(dtype) or isType4Or6Bit(dtype) |
| 43 | + |
| 44 | + |
| 45 | +def isMixedPrecBtwF8AndF4OrF6(dtypeA, dtypeB): |
| 46 | + return (isType8BitFloat(dtypeA) and isType4Or6Bit(dtypeB)) or \ |
| 47 | + (isType8BitFloat(dtypeB) and isType4Or6Bit(dtypeA)) |
| 48 | + |
| 49 | + |
| 50 | +def draw_dot_layout_cmd(M, N, K, dtypeA, dtypeB, mfma_inst_str, isMixed864, plot_scale, dotConfig): |
| 51 | + mfmaNonKDim = dotConfig.mfmaNonKDim |
| 52 | + warpsPerCTA = dotConfig.warpsPerCTA |
| 53 | + trans = 1 if dotConfig.trans else 0 |
| 54 | + kWidth = dotConfig.kWidth |
| 55 | + kGroup = dotConfig.kGroup |
| 56 | + scaleLabel = 0.7 if (kWidth == 4 or (kWidth == 8 and mfmaNonKDim == 32)) else 1 |
| 57 | + |
| 58 | + outType = 'i32' if dtypeA == 'i8' else 'f32' |
| 59 | + kWidth_a = kWidth_b = kWidth |
| 60 | + kGroup_a = kGroup_b = kGroup |
| 61 | + if isMixed864: |
| 62 | + if isType8BitFloat(dtypeA): |
| 63 | + kWidth_a = 16 |
| 64 | + kGroup_a = 2 |
| 65 | + kWidth_b = 32 |
| 66 | + kGroup_b = 1 |
| 67 | + else: |
| 68 | + kWidth_a = 32 |
| 69 | + kGroup_a = 1 |
| 70 | + kWidth_b = 16 |
| 71 | + kGroup_b = 2 |
| 72 | + kWidth_left = kWidth_b if trans else kWidth_a |
| 73 | + kGroup_left = kGroup_b if trans else kGroup_a |
| 74 | + |
| 75 | + elemSmall = 0.04 |
| 76 | + elemLarge = 0.16 |
| 77 | + elemPerThread = kWidth_a * kGroup_a |
| 78 | + if elemPerThread == 16: |
| 79 | + ratio = 0.8 |
| 80 | + elif elemPerThread == 32: |
| 81 | + ratio = 0.6 |
| 82 | + else: |
| 83 | + ratio = 1 |
| 84 | + elemWidth = elemLarge * ratio |
| 85 | + |
| 86 | + scaling = 1 if plot_scale else 0 |
| 87 | + |
| 88 | + return f"""\\begin{{document}} |
| 89 | + \\begin{{tikzpicture}} |
| 90 | + \\def\\scale{{1}} |
| 91 | + \\def\\elem{{{elemSmall}}} |
| 92 | + \\def\\elemW{{\\elem}} |
| 93 | + \\def\\kWidthA{{{kWidth_a}}} |
| 94 | + \\def\\kWidthB{{{kWidth_b}}} |
| 95 | + \\def\\kGroupA{{{kGroup_a}}} |
| 96 | + \\def\\kGroupB{{{kGroup_b}}} |
| 97 | + \\coordinate (C TL) at (0,0); |
| 98 | + \\drawDot{{{M}}}{{{N}}}{{{K}}}{{{mfmaNonKDim}}}{{{warpsPerCTA[0]}}}{{{warpsPerCTA[1]}}}{{{trans}}} |
| 99 | + |
| 100 | + \\coordinate (C TL) at ($(C TL)+({N}*\\elem+32*\\elem, 0)$); |
| 101 | + \\def\\mfmaTrans{{{trans}}} |
| 102 | + |
| 103 | + %% Draw zoomed in view of mfma |
| 104 | + \\def\\scaleLabel{{{scaleLabel}}} |
| 105 | + \\pgfmathsetmacro{{\\oldElem}}{{\\elem}} |
| 106 | + \\def\\elem{{{elemLarge}}} |
| 107 | + \\def\\elemW{{{elemWidth}}} |
| 108 | + \\pgfmathsetmacro{{\\gap}}{{\\elem*5}} |
| 109 | + \\pgfmathsetmacro{{\\nonTrans}}{{1-\\mfmaTrans}} |
| 110 | + \\pgfmathsetmacro{{\\groups}}{{64/{mfmaNonKDim}}} |
| 111 | + \\coordinate (C TL) at ($(C TL)+({scaling}*0.3*\\gap+{scaling}*\\groups*4*\\elemW+.5*\\gap+1.2*\\nonTrans*\\gap+\\groups*{kWidth_left}*{kGroup_left}*\\elemW, -{M}*\\oldElem+{mfmaNonKDim}*\\elem)$); |
| 112 | + \\coordinate (mfma instr) at ($(C TL)+(-.5*\\gap-0.6*\\nonTrans*\\gap-0.4*\\mfmaTrans*\\gap, 1.5*\\gap+.5*\\mfmaTrans*\\gap)$); |
| 113 | + \\node [scale=\\scaleLabel, above left, align=left, draw=black, fill=white] at (mfma instr) {{{mfma_inst_str}}}; |
| 114 | + \\drawMFMAInstr{{{mfmaNonKDim}}}{{\\mfmaTrans}}{{{dtypeA}}}{{{dtypeB}}}{{{outType}}}{{{scaling}}} |
| 115 | + |
| 116 | + \\end{{tikzpicture}} |
| 117 | + \\end{{document}}""" |
| 118 | + |
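A minimal usage sketch (not part of the diff; the parameter values are only an assumption) showing how DotConfig feeds draw_dot_layout_cmd; the returned string is the TikZ body that generate_dot_tex below appends after preamble.tex:

    cfg = DotConfig(mfmaNonKDim=32, kWidth=8, kGroup=1, trans=0, warpsPerCTA=(2, 2))
    # instruction label with underscores pre-escaped for LaTeX, as generate_dot_tex does
    body = draw_dot_layout_cmd(128, 128, 64, 'fp16', 'fp16',
                               "mfma\\_f32\\_32x32x16\\_fp16", False, False, cfg)
    # body starts with \begin{document} and relies on the macros defined in dotLayout.tex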
| 119 | + |
| 120 | +def checkMfmaValidity(mfmaNonKDim, kWidth, kGroup, dtypeA, dtypeB, trans, scale): |
| 121 | + # Check input types |
| 122 | + # Mixed precision is only allowed within f8, f6 and f4 |
| 123 | + assert (isMixedPrecType(dtypeA) and isMixedPrecType(dtypeB)) or \ |
| 124 | + (dtypeA == dtypeB), \ |
| 125 | + f"Cannot do mixed precision mfma with {dtypeA} and {dtypeB}" |
| 126 | + """ |
| 127 | + Check mfma size according to data types |
| 128 | + * refers to newly added instructions on gfx950 |
| 129 | + Both dtypes are f4 or fp6 or bf6 |
| 130 | + *mfma_f32_16x16x128_f8f6f4: kWidth = 32, kGroup = 1 |
| 131 | + *mfma_f32_32x32x64_f8f6f4: kWidth = 32, kGroup = 1 |
| 132 | + One dtype is fp8 or bf8 |
| 133 | + When the other operand is f4, fp6, or bf6 |
| 134 | + *mfma_f32_16x16x128_f8f6f4: kWidth = 16, kGroup = 2 |
| 135 | + *mfma_f32_32x32x64_f8f6f4: kWidth = 16, kGroup = 2 |
| 136 | + When the other operand is fp8 or bf8 |
| 137 | + *mfma_f32_16x16x128_f8f6f4: kWidth = 16, kGroup = 2 |
| 138 | + mfma_f32_16x16x32_fp8/bf8_fp8/bf8: kWidth = 16, kGroup = 1, kpack=2 |
| 139 | + mfma_f32_16x16x32_fp8/bf8_fp8/bf8: kWidth = 8, kGroup = 1 |
| 140 | + *mfma_f32_32x32x64_f8f6f4: kWidth = 16, kGroup = 2 |
| 141 | + mfma_f32_32x32x16_fp8/bf8_fp8/bf8: kWidth = 16, kGroup = 1, kpack=2 |
| 142 | + mfma_f32_32x32x16_fp8/bf8_fp8/bf8: kWidth = 8, kGroup = 1 |
| 143 | + Both dtypes are fp16 or bf16 |
| 144 | + *mfma_f32_16x16x32_f16/bf16: kWidth = 8, kGroup = 1 |
| 145 | + mfma_f32_16x16x16_f16/bf16: kWidth = 4, kGroup = 1 |
| 146 | + *mfma_f32_32x32x16_f16/bf16: kWidth = 8, kGroup = 1 |
| 147 | + mfma_f32_32x32x8_f16/bf16: kWidth = 4, kGroup = 1 |
| 148 | + Both types are i8 |
| 149 | + *mfma_i32_16x16x64_i8: kWidth = 16, kGroup = 1 |
| 150 | + mfma_i32_16x16x32_i8: kWidth = 8, kGroup = 1 |
| 151 | + *mfma_i32_32x32x32_i8: kWidth = 16, kGroup = 1 |
| 152 | + mfma_i32_32x32x16_i8: kWidth = 8, kGroup = 1 |
| 153 | + |
| 154 | + Return mfma instruction name and kpack |
| 155 | + """ |
| 156 | + kDim = 64 / mfmaNonKDim * kWidth * kGroup |
| 157 | + # Both dtypes are f4 or fp6 or bf6 |
| 158 | + if isType4Or6Bit(dtypeA) and isType4Or6Bit(dtypeB): |
| 159 | + assert kWidth == 32 and kGroup == 1, f"Only kWidth=32 and kGroup=1 is supported for {dtypeA} x {dtypeB}" |
| 160 | + kpack = 1 |
| 161 | + CBSZ = matrixFormatTable[dtypeB] if trans else matrixFormatTable[dtypeA] |
| 162 | + BLGP = matrixFormatTable[dtypeA] if trans else matrixFormatTable[dtypeB] |
| 163 | + scale_str = 'scale_' if scale else '' |
| 164 | + return f"mfma_{scale_str}f32_{mfmaNonKDim}x{mfmaNonKDim}x{kDim:.0f}_f8f6f4", kpack, CBSZ, BLGP, scale |
| 165 | + |
| 166 | + # Both dtypes are fp8 or bf8 |
| 167 | + if isType8BitFloat(dtypeA) and isType8BitFloat(dtypeB): |
| 168 | + assert (kWidth == 8 and kGroup == 1) or ( |
| 169 | + kWidth == 16), f"Not a valid mfma instruction for {dtypeA} x {dtypeB} with {kWidth=} and {kGroup=}" |
| 170 | + kpack = 2 if (kWidth == 16 and kGroup == 1) else 1 |
| 171 | + if kGroup == 2: |
| 172 | + suffix = "f8f6f4" |
| 173 | + CBSZ = matrixFormatTable[dtypeB] if trans else matrixFormatTable[dtypeA] |
| 174 | + BLGP = matrixFormatTable[dtypeA] if trans else matrixFormatTable[dtypeB] |
| 175 | + plot_scale = scale |
| 176 | + scale_str = 'scale_' if scale else '' |
| 177 | + else: |
| 178 | + suffix = f"{dtypeB}_{dtypeA}" if trans else f"{dtypeA}_{dtypeB}" |
| 179 | + CBSZ = -1 |
| 180 | + BLGP = -1 |
| 181 | + plot_scale = False |
| 182 | + scale_str = '' |
| 183 | + kDim = kDim / 2 if kpack == 2 else kDim |
| 184 | + return f"mfma_{scale_str}f32_{mfmaNonKDim}x{mfmaNonKDim}x{kDim:.0f}_{suffix}", kpack, CBSZ, BLGP, plot_scale |
| 185 | + |
| 186 | + # Both types are fp16 or bf16 |
| 187 | + if isType16Bit(dtypeA) and isType16Bit(dtypeB): |
| 188 | + assert (kWidth == 8 or kWidth == 4) and kGroup == 1, \ |
| 189 | + f"Not a valid mfma instruction for {dtypeA} x {dtypeB} with {kWidth=} and {kGroup=}" |
| 190 | + kpack = 1 |
| 191 | + CBSZ = -1 |
| 192 | + BLGP = -1 |
| 193 | + return f"mfma_f32_{mfmaNonKDim}x{mfmaNonKDim}x{kDim:.0f}_{dtypeA}", kpack, CBSZ, BLGP, False |
| 194 | + |
| 195 | + # Both types are i8 |
| 196 | + if dtypeA == 'i8' and dtypeB == 'i8': |
| 197 | + assert (kWidth == 16 or kWidth == 8) and kGroup == 1, \ |
| 198 | + f"Not a valid mfma instruction for {dtypeA} x {dtypeB} with {kWidth=} and {kGroup=}" |
| 199 | + kpack = 1 |
| 200 | + CBSZ = -1 |
| 201 | + BLGP = -1 |
| 202 | + return f"mfma_i32_{mfmaNonKDim}x{mfmaNonKDim}x{kDim:.0f}_{dtypeA}", kpack, CBSZ, BLGP, False |
| 203 | + |
| 204 | + assert False, "Mixed precision between fp8/bf8 and fp6/bf6/f4 not supported in this mode" |
| 205 | + |
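To make the docstring above concrete, a few illustrative calls and the tuples they return (instruction name, kpack, CBSZ, BLGP, plot_scale); these are traced by hand from the code above, not part of the diff:

    # both operands f4, 16x16 tile: kDim = 64/16 * 32 * 1 = 128
    checkMfmaValidity(16, 32, 1, 'f4', 'f4', 0, False)
    # -> ('mfma_f32_16x16x128_f8f6f4', 1, 4, 4, False)

    # fp8 x bf8 via the f8f6f4 instruction: kWidth=16, kGroup=2 gives kDim = 64
    checkMfmaValidity(32, 16, 2, 'fp8', 'bf8', 0, False)
    # -> ('mfma_f32_32x32x64_f8f6f4', 1, 0, 1, False)

    # bf16 x bf16 on gfx950: kWidth=8 gives kDim = 32
    checkMfmaValidity(16, 8, 1, 'bf16', 'bf16', 0, False)
    # -> ('mfma_f32_16x16x32_bf16', 1, -1, -1, False)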
| 206 | + |
| 207 | +def generate_dot_tex(args): |
| 208 | + assert args.plot_type == "dot", \ |
| 209 | + f"parsing the wrong arguments. Want dot but have {args.plot_type}" |
| 210 | + # preprocess the args |
| 211 | + dotShape = args.dotShape |
| 212 | + M = dotShape[0] |
| 213 | + N = dotShape[1] |
| 214 | + K = dotShape[2] |
| 215 | + warpsPerCTA = args.warpsPerCTA |
| 216 | + mfmaNonKDim = args.nonKDim |
| 217 | + dtypeA = args.dtypeA |
| 218 | + dtypeB = args.dtypeB |
| 219 | + kWidth = args.kWidth |
| 220 | + kGroup = args.kGroup |
| 221 | + trans = args.mfmaTrans |
| 222 | + scale = args.scale |
| 223 | + # TODO: some of this checking could be done inside the dataclass as well, but that requires quite some refactoring of plot_dot |
| 224 | + dotConfig = DotConfig(mfmaNonKDim, kWidth, kGroup, trans, warpsPerCTA) |
| 225 | + |
| 226 | + # checks and logging |
| 227 | + CTAShape = [ |
| 228 | + mfmaNonKDim * warpsPerCTA[0], |
| 229 | + mfmaNonKDim * warpsPerCTA[1], |
| 230 | + ] |
| 231 | + print(f"Plotting dot operation with shapes {(M, N, K)=}, {kWidth=}, {kGroup=}, {warpsPerCTA=}, {CTAShape=}") |
| 232 | + assert M != 0 and CTAShape[0] <= M and M % CTAShape[0] == 0 and \ |
| 233 | + N != 0 and CTAShape[1] <= N and N % CTAShape[1] == 0, \ |
| 234 | + f"block size ({M}, {N}) should equal to or be multiple of CTA shape ({CTAShape[0]}, {CTAShape[1]})" |
| 235 | + if isMixedPrecBtwF8AndF4OrF6(dtypeA, dtypeB): |
| 236 | + # In the case of mixed precision between 8-bit and 4 or 6-bit, |
| 237 | + # ignore kWidth and kGroup since inA and inB have different kWidth and kGroup values |
| 238 | + if mfmaNonKDim == 16: |
| 239 | + kDim = 128 |
| 240 | + elif mfmaNonKDim == 32: |
| 241 | + kDim = 64 |
| 242 | + else: |
| 243 | + raise NotImplementedError("scaled dot only supports 32x32x64 or 16x16x128 for now") |
| 244 | + assert K != 0 and K % kDim == 0, \ |
| 245 | + f"BLOCK_K = {K} should be spanned by one or multiple of MFMA instructions with KDim = {kDim}" |
| 246 | + kpack = 1 |
| 247 | + CBSZ = matrixFormatTable[dtypeB] if trans else matrixFormatTable[dtypeA] |
| 248 | + BLGP = matrixFormatTable[dtypeA] if trans else matrixFormatTable[dtypeB] |
| 249 | + scale_str = 'scale_' if scale else '' |
| 250 | + mfma_inst_str = f"mfma_{scale_str}f32_{mfmaNonKDim}x{mfmaNonKDim}x{kDim:.0f}_f8f6f4" |
| 251 | + isMixed864 = True |
| 252 | + plot_scale = scale |
| 253 | + else: |
| 254 | + kDim = kWidth * kGroup * 64 // mfmaNonKDim |
| 255 | + assert K % kDim == 0, f"one mfma instruction requires a multiple of {kDim} elements along the k dim, but BLOCK_K = {K}" |
| 256 | + mfma_inst_str, kpack, CBSZ, BLGP, plot_scale = checkMfmaValidity(mfmaNonKDim, kWidth, kGroup, dtypeA, dtypeB, |
| 257 | + trans, scale) |
| 258 | + isMixed864 = False |
| 259 | + flag = '' if CBSZ == -1 else f" with {CBSZ=},{BLGP=}" |
| 260 | + scale_info = " (scale is not supported hence ignored)" if (scale and not plot_scale) else '' |
| 261 | + print(f"MFMA: {mfma_inst_str} x {kpack}{flag}{scale_info}", end="") |
| 262 | + mfma_inst_str = mfma_inst_str.replace("_", "\\_") |
| 263 | + mfma_inst_str = mfma_inst_str + flag |
| 264 | + if kpack == 2: |
| 265 | + mfma_inst_str = mfma_inst_str + " $\\times$ 2" |
| 266 | + if ((dtypeA == 'fp16' or dtypeA == 'bf16') and kWidth == 8) or (dtypeA == 'i8' and kWidth == 16): |
| 267 | + kDim = 64 / mfmaNonKDim * kWidth / 2 |
| 268 | + outType = "i32" if dtypeA == 'i8' else "f32" |
| 269 | + old_instr = f"mfma_{outType}_{mfmaNonKDim}x{mfmaNonKDim}x{kDim:.0f}_{dtypeA}" |
| 270 | + print(f" or {old_instr} x 2") |
| 271 | + old_instr = old_instr.replace("_", "\\_") |
| 272 | + mfma_inst_str = mfma_inst_str + " or\\\\" + old_instr + "$\\times$2" |
| 273 | + else: |
| 274 | + print("") |
| 275 | + |
| 276 | + # write the tex file |
| 277 | + curr_dir = Path(__file__).resolve().parent |
| 278 | + with open("myplot.tex", 'w') as f_plot: |
| 279 | + with open(curr_dir / "../utils/preamble.tex") as file: |
| 280 | + preamble = file.read() |
| 281 | + |
| 282 | + f_plot.write(preamble) |
| 283 | + draw_dotLayout_str = draw_dot_layout_cmd(M, N, K, dtypeA, dtypeB, mfma_inst_str, isMixed864, plot_scale, |
| 284 | + dotConfig) |
| 285 | + func_ref = str(curr_dir / "dotLayout") |
| 286 | + f_plot.write(f"\input{{ {func_ref} }}\n") |
| 287 | + f_plot.write(draw_dotLayout_str) |
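A hedged end-to-end sketch (not part of the diff): generate_dot_tex only needs an object exposing the argparse fields used above, so a SimpleNamespace is enough to drive it, assuming the repo layout around this script provides ../utils/preamble.tex and dotLayout.tex:

    from types import SimpleNamespace

    args = SimpleNamespace(plot_type="dot", dotShape=[128, 128, 128],
                           warpsPerCTA=[2, 2], nonKDim=16, dtypeA='fp8', dtypeB='fp8',
                           kWidth=16, kGroup=2, mfmaTrans=False, scale=False)
    generate_dot_tex(args)  # writes myplot.tex in the current working directory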