@@ -5,113 +5,160 @@ Here is the help info from the script.
55
66``` bash
77> $ python3 plot_layout.py -h
8- usage: Draw triton layouts [-h] [-shape SHAPE SHAPE SHAPE] [-plot {blocked,dot,wmma,lds}] [-nonKDim {16,32}] [-sizePerThread SIZEPERTHREAD SIZEPERTHREAD] [-threadsPerWarp THREADSPERWARP THREADSPERWARP]
9- [-warpsPerCTA WARPSPERCTA WARPSPERCTA] [-order ORDER ORDER] [-kWidth {4,8,16}] [-lds_layout {swizzle,padding,none}] [-lds_access {read,write,none}] [-wave_size {32,64}] [-o O] [-mfmaTrans] [-keep]
8+ usage: Draw triton layouts [-h] [-tensorShape TENSORSHAPE TENSORSHAPE] [-dotShape DOTSHAPE DOTSHAPE DOTSHAPE] [-plot {blocked,dot,wmma,lds}] [-dim0 DIM0] [-dim1 DIM1] [-sizePerThread SIZEPERTHREAD SIZEPERTHREAD]
9+ [-threadsPerWarp THREADSPERWARP THREADSPERWARP] [-warpsPerCTA WARPSPERCTA WARPSPERCTA] [-order ORDER ORDER] [-nonKDim {16,32}] [-kWidth {4,8,16,32}] [-kGroup {1,2}]
10+ [-dtype_a {fp16,bf16,fp8,bf8,fp6,bf6,f4,i8}] [-dtype_b {fp16,bf16,fp8,bf8,fp6,bf6,f4,i8}] [-mfmaTrans] [-scale] [-banks {32,64}] [-lds_layout {swizzle,padding,none}] [-lds_access {read,write,none}]
11+ [-mnContig] [-mfma_trans_load] [-swizzleVec {4,8,16,32}] [-padInterval PADINTERVAL] [-padAmount PADAMOUNT] [-wave_size {32,64}] [-o O] [-keep]
1012
1113options:
1214 -h, --help show this help message and exit
13- -shape SHAPE SHAPE SHAPE
14- Tensor shape in the form of M,N,K
15+ -tensorShape TENSORSHAPE TENSORSHAPE
16+ 2D tensor shape in the form of dim0,dim1
17+ -dotShape DOTSHAPE DOTSHAPE DOTSHAPE
18+ Dot op shape in the form of M,N,K
1519 -plot {blocked,dot,wmma,lds}
1620 choose plot mode
17- -nonKDim {16,32} mfma instruction dim
21+ -dim0 DIM0 tensor dim0 name
22+ -dim1 DIM1 tensor dim1 name
1823 -sizePerThread SIZEPERTHREAD SIZEPERTHREAD
1924 -threadsPerWarp THREADSPERWARP THREADSPERWARP
2025 -warpsPerCTA WARPSPERCTA WARPSPERCTA
2126 -order ORDER ORDER
22- -kWidth {4,8,16} number of elements per thread
27+ -nonKDim {16,32} mfma instruction dim
28+ -kWidth {4,8,16,32} number of contiguous elements per thread
29+ -kGroup {1,2} total number of elements / kWidth per mfma instruction
30+ -dtype_a {fp16,bf16,fp8,bf8,fp6,bf6,f4,i8}
31+ element type of operand A
32+ -dtype_b {fp16,bf16,fp8,bf8,fp6,bf6,f4,i8}
33+ element type of operand B
34+ -mfmaTrans If set, then use mfma.trans layout
35+ -scale If set, plot the scale tensor for mfma_f8f6f4 instructions
36+ -banks {32,64} choose the number of banks in LDS
2337 -lds_layout {swizzle,padding,none}
2438 choose the LDS data layout
2539 -lds_access {read,write,none}
2640 choose LDS access mode
41+ -mnContig If set, the tensor is K x N and n-contig
42+ -mfma_trans_load If set, use MFMA transpose load instructions
43+ -swizzleVec {4,8,16,32}
44+ number of contiguous elements in a vector to swizzle
45+ -padInterval PADINTERVAL
46+ Add padding for every padInterval bytes
47+ -padAmount PADAMOUNT Pad padAmount bytes for every padInterval bytes
2748 -wave_size {32,64} choose the wmma instruction mode
2849 -o O output pdf file name (without surfix)
29- -mfmaTrans If set, then use mfma.trans layout
3050 -keep If set, keep the generated .tex file
3151```
3252
3353## Installation
3454This script does not require torch or triton to be installed. The only package
3555it depends on is latex. On Ubuntu, do
3656``` bash
37- sudo apt install texlive-full
57+ sudo apt-get install texlive-latex-base texlive-latex-extra texlive-fonts-recommended texlive-fonts-extra
58+
3859```
3960
4061## Draw blocked layout (` -plot blocked ` )
4162
4263Examples:
4364``` bash
44- python3 plot_layout.py -plot blocked -shape 128 128 64 -sizePerThread 1 8 -threadsPerWarp 8 8 -warpsPerCTA 4 1
45- python3 plot_layout.py -plot blocked -shape 16 128 64 -sizePerThread 1 8 -threadsPerWarp 16 4 -warpsPerCTA 1 2
46- python3 plot_layout.py -plot blocked -shape 32 128 64 -sizePerThread 8 1 -threadsPerWarp 4 16 -warpsPerCTA 1 2 -order 0 1
65+ python3 plot_layout.py -plot blocked -tensorShape 128 64 -sizePerThread 1 8 -threadsPerWarp 8 8 -warpsPerCTA 4 1
66+ python3 plot_layout.py -plot blocked -tensorShape 16 64 -sizePerThread 1 8 -threadsPerWarp 16 4 -warpsPerCTA 1 2
67+ python3 plot_layout.py -plot blocked -tensorShape 32 64 -sizePerThread 8 1 -threadsPerWarp 4 16 -warpsPerCTA 1 2 -order 0 1
4768```
4869
4970Blocked layouts are used during global load. It is used to describe the layout of the tensor
5071for pointers and results.
51- We can provide tensor shape (` -shape M N K ` ) and blocked layout parameters (
72+ We can provide tensor shape (` -tensorShape dim0 dim1 ` ) and blocked layout parameters (
5273` -sizePerThread x y ` , ` -threadsPerWarp x y ` , and ` -warpsPerCTA x y ` ).
5374We can also provide the order of the tensor as ` -order x y ` to control which dim
5475is the fastest changing dimension.
5576
5677Notes
57- - All of the gemm dims (M, N, and K) are needed when providing the shape. But only
58- M and K will be used to plot the layout of the tensor.
5978- The script does not support the case when threads are loading elements that are
6079 out of the boundary of the tensor dimensions. This means
61- - For M : sizePerThread[ 0] * threadsPerWarps[ 0] * warpsPerCTA[ 0] <= M
62- - For K : sizePerThread[ 1] * threadsPerWarps[ 1] * warpsPerCTA[ 1] <= K
80+ - For dim0 : sizePerThread[ 0] * threadsPerWarps[ 0] * warpsPerCTA[ 0] <= dim0
81+ - For dim1 : sizePerThread[ 1] * threadsPerWarps[ 1] * warpsPerCTA[ 1] <= dim1
6382
6483
6584## Draw mfma operand and result layouts (` -plot dot ` )
6685
6786Examples:
6887``` bash
69- python3 plot_layout.py -plot dot -shape 128 128 64 -warpsPerCTA 2 4 -nonKDim 32 -kWidth 4
70- python3 plot_layout.py -plot dot -shape 128 128 64 -warpsPerCTA 2 4 -nonKDim 32 -kWidth 8
71- python3 plot_layout.py -plot dot -shape 128 128 64 -warpsPerCTA 2 4 -nonKDim 32 -kWidth 8 -mfmaTrans
72- python3 plot_layout.py -plot dot -shape 128 128 64 -warpsPerCTA 2 4 -nonKDim 16 -kWidth 8
73- python3 plot_layout.py -plot dot -shape 128 128 64 -warpsPerCTA 2 4 -nonKDim 16 -kWidth 16
88+ # # i8 inputs
89+ python3 plot_layout.py -plot dot -dotShape 128 128 128 -warpsPerCTA 2 4 -kWidth 8 -dtype_a i8 -dtype_b i8
90+ python3 plot_layout.py -plot dot -dotShape 128 128 128 -warpsPerCTA 2 4 -kWidth 16 -dtype_a i8 -dtype_b i8
91+ # # fp16/bf16 inputs
92+ python3 plot_layout.py -plot dot -dotShape 128 128 128 -warpsPerCTA 2 4 -kWidth 4 -dtype_a fp16 -dtype_b fp16
93+ python3 plot_layout.py -plot dot -dotShape 128 128 128 -warpsPerCTA 2 4 -kWidth 8 -dtype_a fp16 -dtype_b fp16
94+ # # fp8/bf8 inputs
95+ python3 plot_layout.py -plot dot -dotShape 128 128 128 -warpsPerCTA 2 4 -kWidth 8 -dtype_a fp8 -dtype_b bf8
96+ python3 plot_layout.py -plot dot -dotShape 128 128 128 -warpsPerCTA 2 4 -kWidth 16 -dtype_a fp8 -dtype_b bf8
97+ python3 plot_layout.py -plot dot -dotShape 128 128 128 -warpsPerCTA 2 4 -kWidth 16 -kGroup 2 -dtype_a fp8 -dtype_b bf8
98+ # # f4 and fp6/bf6 inputs
99+ python3 plot_layout.py -plot dot -dotShape 128 128 128 -warpsPerCTA 2 4 -kWidth 32 -kGroup 1 -dtype_a f4 -dtype_b bf6
100+ # # fp8/bf8 and fp6/bf6/f4 inputs
101+ python3 plot_layout.py -plot dot -dotShape 128 128 128 -warpsPerCTA 2 4 -kWidth 16 -kGroup 2 -dtype_a fp6 -dtype_b bf8
102+ # # mixed precision with scaling
103+ python3 plot_layout.py -plot dot -dotShape 128 128 128 -warpsPerCTA 2 4 -kWidth 16 -kGroup 2 -dtype_a fp6 -dtype_b bf8 -scale
74104```
75105
106+ One can add ` -nonKDim [16,32] ` and ` -mfmaTrans ` to all of the above examples.
107+
76108This mode draws two graphs:
77- 1 . The layout of the whole tile for tile A, B, and C
109+ 1 . The layout of the dot operation, i.e. tile C = tile A x tile B
781102 . The layout of a single mfma block, operands and results of one or more mfma
79111 instructions that share the same accumulating VGPRs.
80- This view has thread distributions among tensor elements.
81112
82113Knobs
83- - ` -kWidth ` : the number of elements that will be loaded into one thread at once
84- - ` -nonKDim ` : 16 ot 32, which is used to control the mfma instruction size
114+ - ` -kWidth [4,8,16,32] ` : the number of elements that will be loaded into one thread at once
115+ - ` -kGroup [1,2] ` : total number of elements / kWidth for on mfma instruction.
116+ This is 1 for all mfma instructions except for mfma_f32_16x16x128_f8f6f4 and mfma_f32_32x32x64_f8f6f4
117+ with fp8 input types (CBSZ=0 or 1 and/or BLGP=0 or 1)
118+ - ` -nonKDim [16,32] ` : mfma instruction size. The default is set to 16.
85119- ` -mfmaTrans ` : if set, the transposed mfma layout will be plotted.
120+ - ` -dtype_a ` and ` -dtype_b ` : element types of operand A and B. The default value is fp16.
121+ - ` -scale ` : plot scale tensors for A and B. This is only supported with f4/f6 and f8 with ` kGroup=2 ` .
122+ If ` -scale ` is set but not supported, it's ignored.
86123
87124Notes
88125- The layout shows the mapping from the threads/wave to the elements in the
89- original tensor. It does not care if the elements are arranged in LDS, like
90- swizzling to avoid bank conflicts.
91- - The script does not allow settings for data type or k dim of the mfma instruction.
92- This can be controled by the ` -kWidth ` flag.
93- - For example, if we want ` mfma_32x32x8xf16 ` , we can set ` -nonKDim 32 ` and ` -kWidth 4 ` .
94- - If we want ` mfma_32x32x16xf8 ` , we can set ` -nonKDim 32 ` and ` -kWidth 8 ` .
95-
126+ original tensor. It does not matter if LDS is used.
127+ - The script does not allow settings for k dim of the mfma instruction.
128+ This can be controled by the ` -kWidth ` and ` -kGroup ` .
96129
97130## Draw LDS access (` -plot lds ` )
98131
99132Examples:
100133``` bash
101- python3 plot_layout.py -plot lds -lds_layout none -lds_access none -shape 128 128 64 -kWidth 8
134+ python3 plot_layout.py -plot lds -lds_layout none -lds_access none -tensorShape 128 128 -kWidth 8
135+ python3 plot_layout.py -plot lds -lds_layout none -lds_access none -tensorShape 128 128 -kWidth 32 -dtype_a f4
136+ python3 plot_layout.py -plot lds -lds_layout none -lds_access none -tensorShape 128 128 -kWidth 16 -dtype_a fp8 -banks 64
137+ python3 plot_layout.py -plot lds -lds_layout swizzle -lds_access none -tensorShape 128 128 -kWidth 16 -dtype_a fp8 -banks 64
138+ python3 plot_layout.py -plot lds -lds_layout swizzle -lds_access read -tensorShape 128 128 -kWidth 16 -dtype_a bf8 -banks 64
139+ python3 plot_layout.py -plot lds -lds_layout swizzle -lds_access write -tensorShape 128 128 -kWidth 16 -dtype_a f4 -banks 32
140+ python3 plot_layout.py -plot lds -lds_layout none -lds_access read -tensorShape 128 32 -kWidth 4 -dtype_a fp16 -banks 64 -mnContig
141+ python3 plot_layout.py -plot lds -lds_layout swizzle -lds_access read -tensorShape 128 32 -kWidth 16 -dtype_a fp8 -banks 64 -mnContig -mfma_trans_load
142+ python3 plot_layout.py -plot lds -lds_layout padding -lds_access none -tensorShape 128 32 -kWidth 8 -dtype_a fp16 -banks 32 -padInterval 128 -padAmount 16
102143```
103144
104145Knobs
105- - ` kWidth ` here means the vector size when accessing LDS
146+ - ` kWidth ` : the vector size (in unit of elements) when accessing LDS
147+ - ` banks ` : the number of banks in LDS. (64 for gfx950, 32 for pre-gfx950)
148+ - ` dtype_a ` : element data type
106149- Three options for ` -lds_layout ` :
107150 - ` none ` : no swizzling, no padding
108- - ` padding ` : padding at every 128B
109- - ` swizzling ` : apply the swizzling pattern, which is derived from tensor shape and kWidth.
151+ - ` swizzle ` : apply the swizzling pattern, which is derived from tensor shape and kWidth.
152+ - ` padding ` : pad ` padAmount ` bytes for every ` padInterval ` bytes of data
153+ - ` padAmount ` : default is 0
154+ - ` padInterval ` : default is 1
110155- Three options for ` -lds_access ` :
111156 - ` none ` : do not plot access pattern
112- - ` read ` : plot accessed elements during ds_read
113- - ` write ` : plot accessed elements during ds_write. Note that this needs some infomation from
114- global load. Therefore, we need to provide ` -sizePerThread ` and ` -threadsPerWarp ` .
115-
116- Notes
117- - This mode is rarely used. If you have any questions, please contact Lixun Zhang directly.
157+ - ` read ` : plot accessed elements at the first cycle of ds_read
158+ - ` write ` : plot accessed elements during ds_write. For global load access, we assume
159+ a fully coalesced dwordx4 access pattern along the K dim.
160+ - ` mnContig ` : If set, the tile is stored in mn-contig layout. In this layout, elements along
161+ the M/N dim are contiguous in both global memory and LDS.
162+ - ` mfma_trans_load ` : This flag only works when ` mnContig ` is set. When set, ` ds_read_b64_tr_bx `
163+ instructions are used to read from LDS. Note that current triton LDS layout mechanism will
164+ lead to bank conflicts.
0 commit comments