Commit 7613c4d

Support gfx950 layouts (#692)

* Move preamble code into tikzplot.tex
* Rename kpack to kWidth and allow kWidth = 32
* [API change] Take user input to set dim names
  - For blocked layout, use -tensorShape, which only takes two dims as dim0,dim1
  - For dot layout, use -dotShape, which takes three dims as M,N,K
* Re-structure files: separate each layout's code into its own file
* Extend dotLayout plot to support kWidth=32
  - When kWidth is large, use a smaller elemSize horizontally to save space
  - Improve the labels:
    - change vec to kWidth for operands
    - change opA/opB to inA/inB and include operand dims
    - remove group dims in the operands so that they don't overlap with operand block dims
  - Better alignment: dot op and mfma zoomed-in pics are bottom aligned
* [API change] Add support for kGroup. kGroup is defined as total elements per thread / kWidth for one mfma instruction. We need kGroup = 2 only for the newly added mfma_f32_16x16x128_f8f6f4 and mfma_f32_32x32x64_f8f6f4 with f8 input type on MI350.
* [API change] Add support for data types of both operands, and print the mfma instruction name accordingly. For now, mixed-precision mfma between 8-bit and 4- or 6-bit is not supported yet.
* Support mixed mfma with bf8/fp8 and fp6/bf6/f4
* [API change] Add support for scale
* [NFC] Fix format
* [API change] Refactor tensor and LDS layout
  - Support data types
  - Support both 32 and 64 banks
  - Still working on LDS accesses
* [LDS layout] Add support for ds_read access pattern for TN config
  - Fixed the issue with maxPhase computation. Need to submit a PR to fix it in the triton compiler
  - For ds_read_b64 with 64 banks, there are bank conflicts. We need to figure out a different swizzling pattern to avoid bank conflicts.
* [LDS layout] Add support for ds_write access pattern, assuming a basic global access pattern
* [LDS layout] Support access pattern for MN-contig without using mfma_transpose_load instructions
  - Elements along the M/N dim are contiguous in both global memory and LDS. Note that this is not the in-thread transpose case.
  - Swizzling is disabled
* [LDS layout] Support access pattern for MN-contig with mfma_trans_load instructions
* Clean up the code
* [lds layout] Support padding
* Reduce tex packages required
1 parent 35fdfd8 commit 7613c4d

File tree: 8 files changed (+1818, -1059 lines)

python/perf-kernels/tools/plot-layout/README.md

Lines changed: 89 additions & 42 deletions
Here is the help info from the script.

```bash
>$ python3 plot_layout.py -h
usage: Draw triton layouts [-h] [-tensorShape TENSORSHAPE TENSORSHAPE] [-dotShape DOTSHAPE DOTSHAPE DOTSHAPE] [-plot {blocked,dot,wmma,lds}] [-dim0 DIM0] [-dim1 DIM1] [-sizePerThread SIZEPERTHREAD SIZEPERTHREAD]
                           [-threadsPerWarp THREADSPERWARP THREADSPERWARP] [-warpsPerCTA WARPSPERCTA WARPSPERCTA] [-order ORDER ORDER] [-nonKDim {16,32}] [-kWidth {4,8,16,32}] [-kGroup {1,2}]
                           [-dtype_a {fp16,bf16,fp8,bf8,fp6,bf6,f4,i8}] [-dtype_b {fp16,bf16,fp8,bf8,fp6,bf6,f4,i8}] [-mfmaTrans] [-scale] [-banks {32,64}] [-lds_layout {swizzle,padding,none}] [-lds_access {read,write,none}]
                           [-mnContig] [-mfma_trans_load] [-swizzleVec {4,8,16,32}] [-padInterval PADINTERVAL] [-padAmount PADAMOUNT] [-wave_size {32,64}] [-o O] [-keep]

options:
  -h, --help            show this help message and exit
  -tensorShape TENSORSHAPE TENSORSHAPE
                        2D tensor shape in the form of dim0,dim1
  -dotShape DOTSHAPE DOTSHAPE DOTSHAPE
                        Dot op shape in the form of M,N,K
  -plot {blocked,dot,wmma,lds}
                        choose plot mode
  -dim0 DIM0            tensor dim0 name
  -dim1 DIM1            tensor dim1 name
  -sizePerThread SIZEPERTHREAD SIZEPERTHREAD
  -threadsPerWarp THREADSPERWARP THREADSPERWARP
  -warpsPerCTA WARPSPERCTA WARPSPERCTA
  -order ORDER ORDER
  -nonKDim {16,32}      mfma instruction dim
  -kWidth {4,8,16,32}   number of contiguous elements per thread
  -kGroup {1,2}         total number of elements / kWidth per mfma instruction
  -dtype_a {fp16,bf16,fp8,bf8,fp6,bf6,f4,i8}
                        element type of operand A
  -dtype_b {fp16,bf16,fp8,bf8,fp6,bf6,f4,i8}
                        element type of operand B
  -mfmaTrans            If set, then use mfma.trans layout
  -scale                If set, plot the scale tensor for mfma_f8f6f4 instructions
  -banks {32,64}        choose the number of banks in LDS
  -lds_layout {swizzle,padding,none}
                        choose the LDS data layout
  -lds_access {read,write,none}
                        choose LDS access mode
  -mnContig             If set, the tensor is K x N and n-contig
  -mfma_trans_load      If set, use MFMA transpose load instructions
  -swizzleVec {4,8,16,32}
                        number of contiguous elements in a vector to swizzle
  -padInterval PADINTERVAL
                        Add padding for every padInterval bytes
  -padAmount PADAMOUNT  Pad padAmount bytes for every padInterval bytes
  -wave_size {32,64}    choose the wmma instruction mode
  -o O                  output pdf file name (without suffix)
  -keep                 If set, keep the generated .tex file
```

## Installation

This script does not require torch or triton to be installed. The only
package it depends on is LaTeX. On Ubuntu, do
```bash
sudo apt-get install texlive-latex-base texlive-latex-extra texlive-fonts-recommended texlive-fonts-extra
```

## Draw blocked layout (`-plot blocked`)

Examples:
```bash
python3 plot_layout.py -plot blocked -tensorShape 128 64 -sizePerThread 1 8 -threadsPerWarp 8 8 -warpsPerCTA 4 1
python3 plot_layout.py -plot blocked -tensorShape 16 64 -sizePerThread 1 8 -threadsPerWarp 16 4 -warpsPerCTA 1 2
python3 plot_layout.py -plot blocked -tensorShape 32 64 -sizePerThread 8 1 -threadsPerWarp 4 16 -warpsPerCTA 1 2 -order 0 1
```

Blocked layouts are used during global load to describe how the tensor of
pointers and results is laid out across threads.
We provide the tensor shape (`-tensorShape dim0 dim1`) and the blocked layout parameters
(`-sizePerThread x y`, `-threadsPerWarp x y`, and `-warpsPerCTA x y`).
We can also provide the order of the tensor as `-order x y` to control which dim
is the fastest changing dimension.

Notes
- The script does not support the case when threads load elements that are
  out of the boundary of the tensor dimensions. This means the following
  (see the Python sketch after this list):
  - For dim0: sizePerThread[0] * threadsPerWarp[0] * warpsPerCTA[0] <= dim0
  - For dim1: sizePerThread[1] * threadsPerWarp[1] * warpsPerCTA[1] <= dim1
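To make the boundary rule concrete, here is a minimal Python sketch (illustrative only, not code from `plot_layout.py`; the helper name is made up) that checks whether one CTA of a blocked layout fits inside the tensor:

```python
# Check the blocked-layout boundary rule: per-dim CTA coverage must not
# exceed the tensor shape.
def cta_tile(size_per_thread, threads_per_warp, warps_per_cta):
    return [s * t * w for s, t, w in zip(size_per_thread, threads_per_warp, warps_per_cta)]

shape = (128, 64)                          # -tensorShape 128 64
tile = cta_tile([1, 8], [8, 8], [4, 1])    # first example above -> [32, 64]
assert all(t <= d for t, d in zip(tile, shape)), "layout exceeds tensor bounds"
```
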
## Draw mfma operand and result layouts (`-plot dot`)

Examples:
```bash
## i8 inputs
python3 plot_layout.py -plot dot -dotShape 128 128 128 -warpsPerCTA 2 4 -kWidth 8 -dtype_a i8 -dtype_b i8
python3 plot_layout.py -plot dot -dotShape 128 128 128 -warpsPerCTA 2 4 -kWidth 16 -dtype_a i8 -dtype_b i8
## fp16/bf16 inputs
python3 plot_layout.py -plot dot -dotShape 128 128 128 -warpsPerCTA 2 4 -kWidth 4 -dtype_a fp16 -dtype_b fp16
python3 plot_layout.py -plot dot -dotShape 128 128 128 -warpsPerCTA 2 4 -kWidth 8 -dtype_a fp16 -dtype_b fp16
## fp8/bf8 inputs
python3 plot_layout.py -plot dot -dotShape 128 128 128 -warpsPerCTA 2 4 -kWidth 8 -dtype_a fp8 -dtype_b bf8
python3 plot_layout.py -plot dot -dotShape 128 128 128 -warpsPerCTA 2 4 -kWidth 16 -dtype_a fp8 -dtype_b bf8
python3 plot_layout.py -plot dot -dotShape 128 128 128 -warpsPerCTA 2 4 -kWidth 16 -kGroup 2 -dtype_a fp8 -dtype_b bf8
## f4 and fp6/bf6 inputs
python3 plot_layout.py -plot dot -dotShape 128 128 128 -warpsPerCTA 2 4 -kWidth 32 -kGroup 1 -dtype_a f4 -dtype_b bf6
## fp8/bf8 and fp6/bf6/f4 inputs
python3 plot_layout.py -plot dot -dotShape 128 128 128 -warpsPerCTA 2 4 -kWidth 16 -kGroup 2 -dtype_a fp6 -dtype_b bf8
## mixed precision with scaling
python3 plot_layout.py -plot dot -dotShape 128 128 128 -warpsPerCTA 2 4 -kWidth 16 -kGroup 2 -dtype_a fp6 -dtype_b bf8 -scale
```

One can add `-nonKDim [16,32]` and `-mfmaTrans` to all of the above examples.

This mode draws two graphs:
1. The layout of the dot operation, i.e. tile C = tile A x tile B
2. The layout of a single mfma block: the operands and results of one or more mfma
   instructions that share the same accumulating VGPRs.

Knobs
- `-kWidth [4,8,16,32]`: the number of elements that will be loaded into one thread at once
- `-kGroup [1,2]`: total number of elements / kWidth for one mfma instruction.
  This is 1 for all mfma instructions except for mfma_f32_16x16x128_f8f6f4 and mfma_f32_32x32x64_f8f6f4
  with fp8 input types (CBSZ=0 or 1 and/or BLGP=0 or 1). See the worked example at the end of this section.
- `-nonKDim [16,32]`: mfma instruction size. The default is set to 16.
- `-mfmaTrans`: if set, the transposed mfma layout will be plotted.
- `-dtype_a` and `-dtype_b`: element types of operands A and B. The default value is fp16.
- `-scale`: plot scale tensors for A and B. This is only supported with f4/f6 and f8 with `kGroup=2`.
  If `-scale` is set but not supported, it is ignored.

Notes
- The layout shows the mapping from the threads/wave to the elements in the
  original tensor. It does not matter if LDS is used.
- The script does not allow setting the k dim of the mfma instruction directly.
  It is controlled by the `-kWidth` and `-kGroup` flags.
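As a worked example of the `kGroup` rule, here is a small Python sketch. It assumes one operand tile of nonKDim x K is spread evenly over a 64-lane wave; `k_group` is a made-up helper for illustration, not part of `plot_layout.py`:

```python
# kGroup = (total elements per thread for one mfma instruction) / kWidth,
# assuming the nonKDim x K operand tile is spread evenly over the wave.
def k_group(non_k_dim, k_dim, k_width, wave_size=64):
    elems_per_thread = non_k_dim * k_dim // wave_size
    return elems_per_thread // k_width

# mfma_f32_16x16x128_f8f6f4 with fp8 input: 16*128/64 = 32 elements/thread,
# so kWidth=16 gives kGroup=2, matching the fp8/bf8 examples above.
print(k_group(16, 128, 16))   # 2
# With f4 input and kWidth=32, the same instruction gives kGroup=1.
print(k_group(16, 128, 32))   # 1
```
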
## Draw LDS access (`-plot lds`)

Examples:
```bash
python3 plot_layout.py -plot lds -lds_layout none -lds_access none -tensorShape 128 128 -kWidth 8
python3 plot_layout.py -plot lds -lds_layout none -lds_access none -tensorShape 128 128 -kWidth 32 -dtype_a f4
python3 plot_layout.py -plot lds -lds_layout none -lds_access none -tensorShape 128 128 -kWidth 16 -dtype_a fp8 -banks 64
python3 plot_layout.py -plot lds -lds_layout swizzle -lds_access none -tensorShape 128 128 -kWidth 16 -dtype_a fp8 -banks 64
python3 plot_layout.py -plot lds -lds_layout swizzle -lds_access read -tensorShape 128 128 -kWidth 16 -dtype_a bf8 -banks 64
python3 plot_layout.py -plot lds -lds_layout swizzle -lds_access write -tensorShape 128 128 -kWidth 16 -dtype_a f4 -banks 32
python3 plot_layout.py -plot lds -lds_layout none -lds_access read -tensorShape 128 32 -kWidth 4 -dtype_a fp16 -banks 64 -mnContig
python3 plot_layout.py -plot lds -lds_layout swizzle -lds_access read -tensorShape 128 32 -kWidth 16 -dtype_a fp8 -banks 64 -mnContig -mfma_trans_load
python3 plot_layout.py -plot lds -lds_layout padding -lds_access none -tensorShape 128 32 -kWidth 8 -dtype_a fp16 -banks 32 -padInterval 128 -padAmount 16
```

Knobs
- `kWidth`: the vector size (in units of elements) when accessing LDS
- `banks`: the number of banks in LDS (64 for gfx950, 32 for pre-gfx950)
- `dtype_a`: element data type
- Three options for `-lds_layout`:
  - `none`: no swizzling, no padding
  - `swizzle`: apply the swizzling pattern, which is derived from tensor shape and kWidth
  - `padding`: pad `padAmount` bytes for every `padInterval` bytes of data
    (see the sketch after this list)
    - `padAmount`: default is 0
    - `padInterval`: default is 1
- Three options for `-lds_access`:
  - `none`: do not plot access pattern
  - `read`: plot accessed elements at the first cycle of ds_read
  - `write`: plot accessed elements during ds_write. For global load access, we assume
    a fully coalesced dwordx4 access pattern along the K dim.
- `mnContig`: if set, the tile is stored in mn-contig layout. In this layout, elements along
  the M/N dim are contiguous in both global memory and LDS.
- `mfma_trans_load`: this flag only works when `mnContig` is set. When set, `ds_read_b64_tr_bx`
  instructions are used to read from LDS. Note that the current Triton LDS layout mechanism will
  lead to bank conflicts.
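
The padding layout maps a byte offset in the unpadded tile to its LDS address by inserting `padAmount` bytes after every `padInterval` bytes. A minimal sketch of that mapping (my illustration; `padded_offset` is not a function in `plot_layout.py`):

```python
# Insert pad_amount bytes of padding after every pad_interval bytes of data.
def padded_offset(byte_offset, pad_interval, pad_amount):
    return byte_offset + (byte_offset // pad_interval) * pad_amount

# -padInterval 128 -padAmount 16: data byte 256 lands at LDS byte 288.
print(padded_offset(256, 128, 16))   # 288
```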
Lines changed: 157 additions & 0 deletions
```latex
\newcommand{\drawBlockedWave}[5]{
  %%
  %% Draw a wave coverage with blocked layout
  %%
  %% Wave TL: pre defined top-left coordinate of the wave
  %% \elem: pre defined variable
  %%
  %% #1: sizePerThread[0] --> sizePerThreadM
  %% #2: sizePerThread[1] --> sizePerThreadN
  %% #3: threadsPerWarp[0] --> threadsPerWarpM
  %% #4: threadsPerWarp[1] --> threadsPerWarpN
  %% #5: fastest changing dim --> order

  \pgfmathsetmacro{\sizePerThreadM}{#1}
  \pgfmathsetmacro{\sizePerThreadN}{#2}
  \pgfmathsetmacro{\threadsPerWarpM}{#3}
  \pgfmathsetmacro{\threadsPerWarpN}{#4}
  \pgfmathsetmacro{\order}{#5}

  \pgfmathsetmacro{\waveSizeM}{\sizePerThreadM*\threadsPerWarpM}
  \pgfmathsetmacro{\waveSizeN}{\sizePerThreadN*\threadsPerWarpN}

  \foreach \tid in {0,...,63}{
    \pgfmathsetmacro{\tidM}{int(\tid/\threadsPerWarpN)}
    \pgfmathsetmacro{\tidN}{mod(\tid,\threadsPerWarpN)}
    \coordinate (Thread TL) at ($(Wave TL)+(\tidN*\sizePerThreadN*\elem, -\tidM*\sizePerThreadM*\elem)$);
    \pgfmathsetmacro{\ratio}{\tidM*10}

    \ifthenelse{\tid = 0}{
      \draw [line width = 0.01mm, fill=red] (Thread TL)
        rectangle ++(\sizePerThreadN*\elem, -\sizePerThreadM*\elem);
    }{
      \draw [line width = 0.01mm, fill=blue!\ratio!white] (Thread TL)
        rectangle ++(\sizePerThreadN*\elem, -\sizePerThreadM*\elem);
    }
  }
  \draw (Wave TL) rectangle ++(\waveSizeN*\elem, -\waveSizeM*\elem);
}

\newcommand{\drawBlockedCTA}[7]{
  %%
  %% Draw a CTA coverage with blocked layout
  %%
  %% CTA TL: pre defined top-left coordinate of the CTA
  %% \elem: pre defined variable
  %%
  %% #1: sizePerThread[0] --> sizePerThreadM
  %% #2: sizePerThread[1] --> sizePerThreadN
  %% #3: threadsPerWarp[0] --> threadsPerWarpM
  %% #4: threadsPerWarp[1] --> threadsPerWarpN
  %% #5: warpsPerCTA[0] --> warpsPerCTAM
  %% #6: warpsPerCTA[1] --> warpsPerCTAN
  %% #7: fastest changing dim --> order

  \pgfmathsetmacro{\sizePerThreadM}{#1}
  \pgfmathsetmacro{\sizePerThreadN}{#2}
  \pgfmathsetmacro{\threadsPerWarpM}{#3}
  \pgfmathsetmacro{\threadsPerWarpN}{#4}
  \pgfmathsetmacro{\warpsPerCTAM}{#5}
  \pgfmathsetmacro{\warpsPerCTAN}{#6}
  \pgfmathsetmacro{\order}{#7}

  \pgfmathsetmacro{\CTASizeM}{\sizePerThreadM*\threadsPerWarpM*\warpsPerCTAM}
  \pgfmathsetmacro{\CTASizeN}{\sizePerThreadN*\threadsPerWarpN*\warpsPerCTAN}
  \pgfmathsetmacro{\waveSizeM}{\sizePerThreadM*\threadsPerWarpM}
  \pgfmathsetmacro{\waveSizeN}{\sizePerThreadN*\threadsPerWarpN}

  \pgfmathsetmacro{\maxWaveId}{\warpsPerCTAM*\warpsPerCTAN-1}

  \coordinate (Wave TL) at (CTA TL);
  \drawBlockedWave{\sizePerThreadM}{\sizePerThreadN}{\threadsPerWarpM}{\threadsPerWarpN}{\order}
  \foreach \waveId in {0,...,\maxWaveId}{
    \ifthenelse{\order=1}
    {
      \pgfmathsetmacro{\waveCoordM}{int(\waveId/\warpsPerCTAN)}
      \pgfmathsetmacro{\waveCoordN}{mod(\waveId,\warpsPerCTAN)}
      \pgfmathsetmacro{\rot}{0}
    }{
      \pgfmathsetmacro{\waveCoordM}{mod(\waveId,\warpsPerCTAM)}
      \pgfmathsetmacro{\waveCoordN}{int(\waveId/\warpsPerCTAM)}
      \pgfmathsetmacro{\rot}{90}
    }

    \coordinate (Wave TL) at ($(CTA TL)+(\waveCoordN*\waveSizeN*\elem, -\waveCoordM*\waveSizeM*\elem)$);
    \draw [ultra thin] (Wave TL) rectangle ++(\waveSizeN*\elem, -\waveSizeM*\elem)
      node [pos=.5, scale=.6*\scale, inner sep=0, fill=white, rotate=\rot] {wave\waveId};
  }

  \draw [thick] (CTA TL) rectangle ++(\CTASizeN*\elem, -\CTASizeM*\elem);
}

\newcommand{\drawBlockedTensor}[8]{
  %%
  %% Draw a tensor with blocked layout of the following parameters
  %% sizePerThread[2]
  %% threadsPerWarp[2]
  %% warpsPerCTA[2]
  %% order[2]
  %%
  %% TL: pre defined top-left coordinate of the tensor
  %% \elem: pre defined variable
  %% \dimColName: dim0Name
  %% \dimRowName: dim1Name
  %%
  %% #1: tensorShape[0] --> M
  %% #2: tensorShape[1] --> N
  %% #3: sizePerThread[0] --> sizePerThreadM
  %% #4: sizePerThread[1] --> sizePerThreadN
  %% #5: threadsPerWarp[0] --> threadsPerWarpM
  %%     Note that threadsPerWarp[1] is calculated by 64/threadsPerWarp[0]
  %% #6: warpsPerCTA[0] --> warpsPerCTAM
  %% #7: warpsPerCTA[1] --> warpsPerCTAN
  %% #8: fastest changing dim --> order

  \pgfmathsetmacro{\M}{#1}
  \pgfmathsetmacro{\N}{#2}
  \pgfmathsetmacro{\sizePerThreadM}{#3}
  \pgfmathsetmacro{\sizePerThreadN}{#4}
  \pgfmathsetmacro{\threadsPerWarpM}{#5}
  \pgfmathsetmacro{\warpsPerCTAM}{#6}
  \pgfmathsetmacro{\warpsPerCTAN}{#7}
  \pgfmathsetmacro{\order}{#8}

  \pgfmathsetmacro{\threadsPerWarpN}{64/\threadsPerWarpM}
  \pgfmathsetmacro{\CTASizeM}{\sizePerThreadM*\threadsPerWarpM*\warpsPerCTAM}
  \pgfmathsetmacro{\CTASizeN}{\sizePerThreadN*\threadsPerWarpN*\warpsPerCTAN}
  \pgfmathsetmacro{\CTARepM}{\M/\CTASizeM}
  \pgfmathsetmacro{\CTARepN}{\N/\CTASizeN}
  \pgfmathsetmacro{\maxCTAId}{\CTARepM*\CTARepN-1}

  \foreach \ctaId in {0,...,\maxCTAId}{
    \pgfmathsetmacro{\ctaCoordM}{int(\ctaId/\CTARepN)}
    \pgfmathsetmacro{\ctaCoordN}{mod(\ctaId,\CTARepN)}
    \coordinate (CTA TL) at ($(TL)+(\ctaCoordN*\CTASizeN*\elem, -\ctaCoordM*\CTASizeM*\elem)$);
    \drawBlockedCTA{\sizePerThreadM}{\sizePerThreadN}{\threadsPerWarpM}{\threadsPerWarpN}{\warpsPerCTAM}{\warpsPerCTAN}{\order}
  }

  \node [scale=.7*\scale, above, rotate=90] at ($(TL)+(0, -.5*\M*\elem)$) {\dimColName=\M};
  \node [scale=.7*\scale, above] at ($(TL)+(.5*\N*\elem, 0)$) {\dimRowName=\N};

  \def\zoomR{1.5}
  \coordinate (zoomin BL) at ($(TL)+(0, .3)$);

  \foreach \hl in {0,...,\sizePerThreadM}{
    \draw ($(zoomin BL)+(0, \hl*\elem*\zoomR)$) -- ++(\sizePerThreadN*\elem*\zoomR,0);
  }
  \foreach \vl in {0,...,\sizePerThreadN}{
    \draw ($(zoomin BL)+(\vl*\elem*\zoomR, 0)$) -- ++(0, \sizePerThreadM*\elem*\zoomR);
  }

  \node [scale=.6*\scale, left] at ($(zoomin BL)+(0, .5*\sizePerThreadM*\elem*\zoomR)$) {$t_0$};
  \node [scale=.6*\scale, right] at ($(zoomin BL)+(\sizePerThreadN*\elem*\zoomR, .5*\sizePerThreadM*\elem*\zoomR)$) {\sizePerThreadM$\times$\sizePerThreadN};

  \draw [densely dotted] (TL) -- (zoomin BL);
  \draw [densely dotted] ($(TL)+(\sizePerThreadN*\elem, 0)$) -- ($(zoomin BL)+(\sizePerThreadN*\elem*\zoomR, 0)$);
  \draw [fill=red] (TL) rectangle ++(\sizePerThreadN*\elem, -\sizePerThreadM*\elem);
}
```
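
To cross-check the thread-to-element mapping that `\drawBlockedWave` draws, here is a Python mirror of its indexing (the same arithmetic as the `\tidM`/`\tidN` macros above; the function name is made up for illustration):

```python
# Map each lane of a 64-thread wave to the top-left element it owns,
# mirroring \drawBlockedWave: tidM = tid // threadsPerWarpN,
# tidN = tid % threadsPerWarpN, scaled by sizePerThread.
def wave_coverage(size_m, size_n, warp_m, warp_n):
    return {tid: ((tid // warp_n) * size_m, (tid % warp_n) * size_n)
            for tid in range(warp_m * warp_n)}

# sizePerThread=[1,8], threadsPerWarp=[8,8]: lane 9 starts at element (1, 8).
print(wave_coverage(1, 8, 8, 8)[9])   # (1, 8)
```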
