Commit 32a8fdd

jtang10 (Jingning Tang) authored

plot_layout.py Refactoring (#775)

* refactored plot_layout.py
* disabled all abbreviation
* fixed relative path issue

Co-authored-by: Jingning Tang <[email protected]>
1 parent a18fc87 commit 32a8fdd

File tree

17 files changed: +913 −690 lines


python/perf-kernels/tools/plot-layout/README.md

Lines changed: 112 additions & 81 deletions
Large diffs are not rendered by default.
Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
from .plot_blocked import generate_blocked_tex

__all__ = ["generate_blocked_tex"]

python/perf-kernels/tools/plot-layout/blockedLayout.tex renamed to python/perf-kernels/tools/plot-layout/blocked/blockedLayout.tex

File renamed without changes.
Lines changed: 74 additions & 0 deletions
@@ -0,0 +1,74 @@
from dataclasses import dataclass
from pathlib import Path


@dataclass
class BlockedConfig:
    sizePerThread: tuple
    threadsPerWarp: tuple
    warpsPerCTA: tuple
    order: tuple


def draw_blocked_layout_cmd(dim0, dim1, dim0Name, dim1Name, blockedConfig):
    return f"""\\begin{{document}}
\\begin{{tikzpicture}}
\\def\\scale{{1}}
\\def\\elem{{0.06}}
\\coordinate (TL) at (0,0);
\\def\\dimColName{{{dim0Name}}}
\\def\\dimRowName{{{dim1Name}}}
\\drawBlockedTensor{{{dim0}}}{{{dim1}}}{{{blockedConfig.sizePerThread[0]}}}{{{blockedConfig.sizePerThread[1]}}}{{{blockedConfig.threadsPerWarp[0]}}}{{{blockedConfig.warpsPerCTA[0]}}}{{{blockedConfig.warpsPerCTA[1]}}}{{{blockedConfig.order[0]}}}
\\end{{tikzpicture}}
\\end{{document}}"""


def generate_blocked_tex(args):
    """Generate the tex file of the blocked layout and draw it out"""
    assert args.plot_type == "blocked", \
        f"parsing the wrong arguments. Want blocked but have {args.plot_type}"
    # preprocess the args
    # shortcut to plot dot operand B to save some cmd args
    if args.matrixB:
        dim0Name, dim1Name = "K", "N"
    else:
        dim0Name, dim1Name = args.rowName, args.colName
    # TODO: this can be further refactored to absorb the assertions below and make it more elegant
    sizePerThread = args.sizePerThread
    threadsPerWarp = args.threadsPerWarp
    warpsPerCTA = args.warpsPerCTA
    order = args.order
    blockedConfig = BlockedConfig(sizePerThread, threadsPerWarp, warpsPerCTA, order)
    CTAShape = [
        sizePerThread[0] * threadsPerWarp[0] * warpsPerCTA[0],
        sizePerThread[1] * threadsPerWarp[1] * warpsPerCTA[1],
    ]

    # checks and logging
    if args.blockShape is not None:
        dim0, dim1 = args.blockShape
    else:
        print(f"Since the block size is not explicitly defined, assume block size = CTAShape = {CTAShape}")
        dim0, dim1 = CTAShape
    print(f"Plotting a block [{dim0Name}, {dim1Name}] = [{dim0}, {dim1}] with the following blocked layout:")
    print(f"{sizePerThread=}", end=", ")
    print(f"{threadsPerWarp=}", end=", ")
    print(f"{warpsPerCTA=}", end=", ")
    print(f"{order=}", end=", ")
    print(f"CTAShape={CTAShape}")
    assert dim0 != 0 and CTAShape[0] <= dim0 and dim0 % CTAShape[0] == 0, \
        f"CTAShape[0] should be no larger than {dim0Name}={dim0} and should evenly divide it"
    assert dim1 != 0 and CTAShape[1] <= dim1 and dim1 % CTAShape[1] == 0, \
        f"CTAShape[1] should be no larger than {dim1Name}={dim1} and should evenly divide it"

    # write the tex file
    curr_dir = Path(__file__).resolve().parent
    with open("myplot.tex", 'w') as f_plot:
        with open(curr_dir / "../utils/preamble.tex") as file:
            preamble = file.read()

        f_plot.write(preamble)
        draw_blockedLayout_str = draw_blocked_layout_cmd(dim0, dim1, dim0Name, dim1Name, blockedConfig)
        func_ref = str(curr_dir / "blockedLayout")
        f_plot.write(f"\\input{{ {func_ref} }}\n")
        f_plot.write(draw_blockedLayout_str)
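
For orientation while reading the diff, here is a minimal, hypothetical driver for the new module. The attribute names on args mirror what generate_blocked_tex reads above; the concrete values, the argparse.Namespace trick, and the import path are illustrative assumptions, not part of this commit.

    # Hypothetical usage sketch (not part of this commit): feed generate_blocked_tex
    # an argparse-style namespace carrying the attributes it reads.
    from argparse import Namespace

    from blocked import generate_blocked_tex  # assumes the plot-layout directory is on sys.path

    args = Namespace(
        plot_type="blocked",      # must be "blocked" to pass the assert
        sizePerThread=(1, 8),     # illustrative layout parameters
        threadsPerWarp=(8, 8),
        warpsPerCTA=(4, 1),
        order=(1, 0),
        blockShape=None,          # None -> block size defaults to CTAShape = [32, 64] here
        matrixB=False,            # False -> label axes with rowName/colName
        rowName="M",
        colName="K",
    )
    generate_blocked_tex(args)    # writes myplot.tex into the current working directory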
Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
from .plot_dot import generate_dot_tex

__all__ = ["generate_dot_tex"]
File renamed without changes.
Lines changed: 287 additions & 0 deletions
@@ -0,0 +1,287 @@
from dataclasses import dataclass
from pathlib import Path


@dataclass
class DotConfig:
    mfmaNonKDim: int
    kWidth: int
    kGroup: int
    trans: int
    warpsPerCTA: tuple


matrixFormatTable = {'fp8': 0, 'bf8': 1, 'fp6': 2, 'bf6': 3, 'f4': 4}


def matrixFormat(dtypeA, dtypeB):
    """
    return CBSZ and BLGP according to data types
    b000: E4M3(FP8)
    b001: E5M2(BF8)
    b010: E2M3(FP6)
    b011: E3M2(BF6)
    b100: E2M1(FP4)
    """
    return matrixFormatTable[dtypeA], matrixFormatTable[dtypeB]


def isType4Or6Bit(dtype):
    return dtype == 'fp6' or dtype == 'bf6' or dtype == 'f4'


def isType8BitFloat(dtype):
    return dtype == 'fp8' or dtype == 'bf8'


def isType16Bit(dtype):
    return dtype == 'bf16' or dtype == 'fp16'


def isMixedPrecType(dtype):
    return isType8BitFloat(dtype) or isType4Or6Bit(dtype)


def isMixedPrecBtwF8AndF4OrF6(dtypeA, dtypeB):
    return (isType8BitFloat(dtypeA) and isType4Or6Bit(dtypeB)) or \
        (isType8BitFloat(dtypeB) and isType4Or6Bit(dtypeA))


def draw_dot_layout_cmd(M, N, K, dtypeA, dtypeB, mfma_inst_str, isMixed864, plot_scale, dotConfig):
    mfmaNonKDim = dotConfig.mfmaNonKDim
    warpsPerCTA = dotConfig.warpsPerCTA
    trans = 1 if dotConfig.trans else 0
    kWidth = dotConfig.kWidth
    kGroup = dotConfig.kGroup
    scaleLabel = 0.7 if (kWidth == 4 or (kWidth == 8 and mfmaNonKDim == 32)) else 1

    outType = 'i32' if dtypeA == 'i8' else 'f32'
    kWidth_a = kWidth_b = kWidth
    kGroup_a = kGroup_b = kGroup
    if isMixed864:
        if isType8BitFloat(dtypeA):
            kWidth_a = 16
            kGroup_a = 2
            kWidth_b = 32
            kGroup_b = 1
        else:
            kWidth_a = 32
            kGroup_a = 1
            kWidth_b = 16
            kGroup_b = 2
    kWidth_left = kWidth_b if trans else kWidth_a
    kGroup_left = kGroup_b if trans else kGroup_a

    elemSmall = 0.04
    elemLarge = 0.16
    elemPerThread = kWidth_a * kGroup_a
    if elemPerThread == 16:
        ratio = 0.8
    elif elemPerThread == 32:
        ratio = 0.6
    else:
        ratio = 1
    elemWidth = elemLarge * ratio

    scaling = 1 if plot_scale else 0

    return f"""\\begin{{document}}
\\begin{{tikzpicture}}
\\def\\scale{{1}}
\\def\\elem{{{elemSmall}}}
\\def\\elemW{{\\elem}}
\\def\\kWidthA{{{kWidth_a}}}
\\def\\kWidthB{{{kWidth_b}}}
\\def\\kGroupA{{{kGroup_a}}}
\\def\\kGroupB{{{kGroup_b}}}
\\coordinate (C TL) at (0,0);
\\drawDot{{{M}}}{{{N}}}{{{K}}}{{{mfmaNonKDim}}}{{{warpsPerCTA[0]}}}{{{warpsPerCTA[1]}}}{{{trans}}}

\\coordinate (C TL) at ($(C TL)+({N}*\\elem+32*\\elem, 0)$);
\\def\\mfmaTrans{{{trans}}}

%% Draw zoomed in view of mfma
\\def\\scaleLabel{{{scaleLabel}}}
\\pgfmathsetmacro{{\\oldElem}}{{\\elem}}
\\def\\elem{{{elemLarge}}}
\\def\\elemW{{{elemWidth}}}
\\pgfmathsetmacro{{\\gap}}{{\\elem*5}}
\\pgfmathsetmacro{{\\nonTrans}}{{1-\\mfmaTrans}}
\\pgfmathsetmacro{{\\groups}}{{64/{mfmaNonKDim}}}
\\coordinate (C TL) at ($(C TL)+({scaling}*0.3*\\gap+{scaling}*\\groups*4*\\elemW+.5*\\gap+1.2*\\nonTrans*\\gap+\\groups*{kWidth_left}*{kGroup_left}*\\elemW, -{M}*\\oldElem+{mfmaNonKDim}*\\elem)$);
\\coordinate (mfma instr) at ($(C TL)+(-.5*\\gap-0.6*\\nonTrans*\\gap-0.4*\\mfmaTrans*\\gap, 1.5*\\gap+.5*\\mfmaTrans*\\gap)$);
\\node [scale=\\scaleLabel, above left, align=left, draw=black, fill=white] at (mfma instr) {{{mfma_inst_str}}};
\\drawMFMAInstr{{{mfmaNonKDim}}}{{\\mfmaTrans}}{{{dtypeA}}}{{{dtypeB}}}{{{outType}}}{{{scaling}}}

\\end{{tikzpicture}}
\\end{{document}}"""


def checkMfmaValidity(mfmaNonKDim, kWidth, kGroup, dtypeA, dtypeB, trans, scale):
    # Check input types
    # Mixed precision is only allowed within f8, f6 and f4
    assert (isMixedPrecType(dtypeA) and isMixedPrecType(dtypeB)) or \
        (dtypeA == dtypeB), \
        f"Cannot do mixed precision mfma with {dtypeA} and {dtypeB}"
    """
    Check mfma size according to data types
    * refers to newly added instructions on gfx950
    Both dtypes are f4 or fp6 or bf6
        *mfma_f32_16x16x128_f8f6f4: kWidth = 32, kGroup = 1
        *mfma_f32_32x32x64_f8f6f4: kWidth = 32, kGroup = 1
    One dtype is fp8 or bf8
        When the other operand is f4, fp6, or bf6
            *mfma_f32_16x16x128_f8f6f4: kWidth = 16, kGroup = 2
            *mfma_f32_32x32x64_f8f6f4: kWidth = 16, kGroup = 2
        When the other operand is fp8 or bf8
            *mfma_f32_16x16x128_f8f6f4: kWidth = 16, kGroup = 2
            mfma_f32_16x16x32_fp8/bf8_fp8/bf8: kWidth = 16, kGroup = 1, kpack = 2
            mfma_f32_16x16x32_fp8/bf8_fp8/bf8: kWidth = 8, kGroup = 1
            *mfma_f32_32x32x64_f8f6f4: kWidth = 16, kGroup = 2
            mfma_f32_32x32x16_fp8/bf8_fp8/bf8: kWidth = 16, kGroup = 1, kpack = 2
            mfma_f32_32x32x16_fp8/bf8_fp8/bf8: kWidth = 8, kGroup = 1
    Both dtypes are fp16 or bf16
        *mfma_f32_16x16x32_f16/bf16: kWidth = 8, kGroup = 1
        mfma_f32_16x16x16_f16/bf16: kWidth = 4, kGroup = 1
        *mfma_f32_32x32x16_f16/bf16: kWidth = 8, kGroup = 1
        mfma_f32_32x32x8_f16/bf16: kWidth = 4, kGroup = 1
    Both types are i8
        *mfma_i32_16x16x64_i8: kWidth = 16, kGroup = 1
        mfma_i32_16x16x32_i8: kWidth = 8, kGroup = 1
        *mfma_i32_32x32x32_i8: kWidth = 16, kGroup = 1
        mfma_i32_32x32x16_i8: kWidth = 8, kGroup = 1

    Return mfma instruction name and kpack
    """
    kDim = 64 / mfmaNonKDim * kWidth * kGroup
    # Both dtypes are f4 or fp6 or bf6
    if isType4Or6Bit(dtypeA) and isType4Or6Bit(dtypeB):
        assert kWidth == 32 and kGroup == 1, f"Only kWidth=32 and kGroup=1 is supported for {dtypeA} x {dtypeB}"
        kpack = 1
        CBSZ = matrixFormatTable[dtypeB] if trans else matrixFormatTable[dtypeA]
        BLGP = matrixFormatTable[dtypeA] if trans else matrixFormatTable[dtypeB]
        scale_str = 'scale_' if scale else ''
        return f"mfma_{scale_str}f32_{mfmaNonKDim}x{mfmaNonKDim}x{kDim:.0f}_f8f6f4", kpack, CBSZ, BLGP, scale

    # Both dtypes are fp8 or bf8
    if isType8BitFloat(dtypeA) and isType8BitFloat(dtypeB):
        assert (kWidth == 8 and kGroup == 1) or (
            kWidth == 16), f"Not a valid mfma instruction for {dtypeA} x {dtypeB} with {kWidth=} and {kGroup=}"
        kpack = 2 if (kWidth == 16 and kGroup == 1) else 1
        if kGroup == 2:
            suffix = "f8f6f4"
            CBSZ = matrixFormatTable[dtypeB] if trans else matrixFormatTable[dtypeA]
            BLGP = matrixFormatTable[dtypeA] if trans else matrixFormatTable[dtypeB]
            plot_scale = scale
            scale_str = 'scale_' if scale else ''
        else:
            suffix = f"{dtypeB}_{dtypeA}" if trans else f"{dtypeA}_{dtypeB}"
            CBSZ = -1
            BLGP = -1
            plot_scale = False
            scale_str = ''
        kDim = kDim / 2 if kpack == 2 else kDim
        return f"mfma_{scale_str}f32_{mfmaNonKDim}x{mfmaNonKDim}x{kDim:.0f}_{suffix}", kpack, CBSZ, BLGP, plot_scale

    # Both types are fp16 or bf16
    if isType16Bit(dtypeA) and isType16Bit(dtypeB):
        assert (kWidth == 8 or kWidth == 4) and kGroup == 1, \
            f"Not a valid mfma instruction for {dtypeA} x {dtypeB} with {kWidth=} and {kGroup=}"
        kpack = 1
        CBSZ = -1
        BLGP = -1
        return f"mfma_f32_{mfmaNonKDim}x{mfmaNonKDim}x{kDim:.0f}_{dtypeA}", kpack, CBSZ, BLGP, False

    # Both types are i8
    if dtypeA == 'i8' and dtypeB == 'i8':
        assert (kWidth == 16 or kWidth == 8) and kGroup == 1, \
            f"Not a valid mfma instruction for {dtypeA} x {dtypeB} with {kWidth=} and {kGroup=}"
        kpack = 1
        CBSZ = -1
        BLGP = -1
        return f"mfma_i32_{mfmaNonKDim}x{mfmaNonKDim}x{kDim:.0f}_{dtypeA}", kpack, CBSZ, BLGP, False

    assert False, "Mixed precision between fp8/bf8 and fp6/bf6/f4 is not supported in this mode"


def generate_dot_tex(args):
    assert args.plot_type == "dot", \
        f"parsing the wrong arguments. Want dot but have {args.plot_type}"
    # preprocess the args
    dotShape = args.dotShape
    M = dotShape[0]
    N = dotShape[1]
    K = dotShape[2]
    warpsPerCTA = args.warpsPerCTA
    mfmaNonKDim = args.nonKDim
    dtypeA = args.dtypeA
    dtypeB = args.dtypeB
    kWidth = args.kWidth
    kGroup = args.kGroup
    trans = args.mfmaTrans
    scale = args.scale
    # TODO: some of the checking can be done inside this dataclass as well,
    # but plot_dot requires quite some refactoring for that
    dotConfig = DotConfig(mfmaNonKDim, kWidth, kGroup, trans, warpsPerCTA)

    # checks and logging
    CTAShape = [
        mfmaNonKDim * warpsPerCTA[0],
        mfmaNonKDim * warpsPerCTA[1],
    ]
    print(f"Plotting dot operation with shapes {(M, N, K)=}, {kWidth=}, {kGroup=}, {warpsPerCTA=}, {CTAShape=}")
    assert M != 0 and CTAShape[0] <= M and M % CTAShape[0] == 0 and \
        N != 0 and CTAShape[1] <= N and N % CTAShape[1] == 0, \
        f"block size ({M}, {N}) should be equal to or a multiple of the CTA shape ({CTAShape[0]}, {CTAShape[1]})"
    if isMixedPrecBtwF8AndF4OrF6(dtypeA, dtypeB):
        # In the case of mixed precision between 8-bit and 4- or 6-bit,
        # ignore kWidth and kGroup since inA and inB have different kWidth and kGroup values
        if mfmaNonKDim == 16:
            kDim = 128
        elif mfmaNonKDim == 32:
            kDim = 64
        else:
            raise NotImplementedError("scaled dot only supports 32x32x64 or 16x16x128 for now")
        assert K != 0 and K % kDim == 0, \
            f"BLOCK_K = {K} should be spanned by one or multiple MFMA instructions with KDim = {kDim}"
        kpack = 1
        CBSZ = matrixFormatTable[dtypeB] if trans else matrixFormatTable[dtypeA]
        BLGP = matrixFormatTable[dtypeA] if trans else matrixFormatTable[dtypeB]
        scale_str = 'scale_' if scale else ''
        mfma_inst_str = f"mfma_{scale_str}f32_{mfmaNonKDim}x{mfmaNonKDim}x{kDim:.0f}_f8f6f4"
        isMixed864 = True
        plot_scale = scale
    else:
        kDim = kWidth * kGroup * 64 // mfmaNonKDim
        assert K % kDim == 0, f"one mfma instruction requires a multiple of {kDim} elements along the k dim but BLOCK_K = {K}"
        mfma_inst_str, kpack, CBSZ, BLGP, plot_scale = checkMfmaValidity(mfmaNonKDim, kWidth, kGroup, dtypeA, dtypeB,
                                                                         trans, scale)
        isMixed864 = False
    flag = '' if CBSZ == -1 else f" with {CBSZ=},{BLGP=}"
    scale_info = " (scale is not supported hence ignored)" if (scale and not plot_scale) else ''
    print(f"MFMA: {mfma_inst_str} x {kpack}{flag}{scale_info}", end="")
    mfma_inst_str = mfma_inst_str.replace("_", "\\_")
    mfma_inst_str = mfma_inst_str + flag
    if kpack == 2:
        mfma_inst_str = mfma_inst_str + " $\\times$ 2"
    if ((dtypeA == 'fp16' or dtypeA == 'bf16') and kWidth == 8) or (dtypeA == 'i8' and kWidth == 16):
        kDim = 64 / mfmaNonKDim * kWidth / 2
        outType = "i32" if dtypeA == 'i8' else "f32"
        old_instr = f"mfma_{outType}_{mfmaNonKDim}x{mfmaNonKDim}x{kDim:.0f}_{dtypeA}"
        print(f" or {old_instr} x 2")
        old_instr = old_instr.replace("_", "\\_")
        mfma_inst_str = mfma_inst_str + " or\\\\" + old_instr + "$\\times$2"
    else:
        print("")

    # write the tex file
    curr_dir = Path(__file__).resolve().parent
    with open("myplot.tex", 'w') as f_plot:
        with open(curr_dir / "../utils/preamble.tex") as file:
            preamble = file.read()

        f_plot.write(preamble)
        draw_dotLayout_str = draw_dot_layout_cmd(M, N, K, dtypeA, dtypeB, mfma_inst_str, isMixed864, plot_scale,
                                                 dotConfig)
        func_ref = str(curr_dir / "dotLayout")
        f_plot.write(f"\\input{{ {func_ref} }}\n")
        f_plot.write(draw_dotLayout_str)
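
To make the kDim bookkeeping above concrete, here is a small hedged example that exercises checkMfmaValidity directly. The expected instruction names follow from kDim = 64 / mfmaNonKDim * kWidth * kGroup in the code; the import path assumes the new dot/plot_dot.py module layout and is an assumption of this sketch.

    # Hypothetical sketch (not part of this commit): map (mfmaNonKDim, kWidth, kGroup, dtypes)
    # to an MFMA instruction name via checkMfmaValidity.
    from dot.plot_dot import checkMfmaValidity  # assumed import path

    # fp16 x fp16, 32x32 tile, kWidth=8, kGroup=1 -> kDim = 64/32 * 8 * 1 = 16
    name, kpack, CBSZ, BLGP, plot_scale = checkMfmaValidity(32, 8, 1, 'fp16', 'fp16', trans=0, scale=False)
    print(name, kpack, CBSZ, BLGP)  # mfma_f32_32x32x16_fp16 1 -1 -1

    # fp8 x bf8, 16x16 tile, kWidth=16, kGroup=2 -> kDim = 64/16 * 16 * 2 = 128 (f8f6f4 path)
    name, kpack, CBSZ, BLGP, plot_scale = checkMfmaValidity(16, 16, 2, 'fp8', 'bf8', trans=0, scale=False)
    print(name, kpack, CBSZ, BLGP)  # mfma_f32_16x16x128_f8f6f4 1 0 1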
Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
from .plot_lds import generate_lds_tex

__all__ = ["generate_lds_tex"]
File renamed without changes.
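
Each of the three new packages (blocked, dot, lds) exports a single generate_*_tex entry point that re-asserts args.plot_type. The refactored top-level plot_layout.py is not rendered in this view; purely as an illustration, a dispatcher consistent with those exports could look like the sketch below (everything beyond the three imports is an assumption).

    # Hypothetical dispatcher sketch; the real plot_layout.py changes are not shown in this excerpt.
    from blocked import generate_blocked_tex
    from dot import generate_dot_tex
    from lds import generate_lds_tex

    GENERATORS = {
        "blocked": generate_blocked_tex,
        "dot": generate_dot_tex,
        "lds": generate_lds_tex,
    }


    def generate_tex(args):
        """Route the parsed arguments to the generator for args.plot_type."""
        try:
            generator = GENERATORS[args.plot_type]
        except KeyError:
            raise ValueError(f"unknown plot type: {args.plot_type}")
        generator(args)  # each generator re-checks plot_type and writes myplot.tex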
