Commit 1a8a0a7

Update the DPAS encoding documents.
1 parent 76c054e commit 1a8a0a7

1 file changed: +143 -38 lines changed

third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td

Lines changed: 143 additions & 38 deletions
@@ -14,7 +14,8 @@ def DpasEncodingAttr : DistributedEncoding<"DpasEncoding", "intel_dpas_encoding"
1414
let mnemonic = "dpas";
1515

1616
let description = [{
17-
An encoding for the tensors distributed across the threads for the C and D operands of XMX tensor core operation.
17+
An encoding for tensors distributed across the threads for the C and D operands of an XMX tensor core operation,
18+
and for the layouts of the corresponding A and B operands that use the DPAS encoding as their parent.
1819
The XMX tensor core operation is defined for matrix matmul as: D=A*B+C
1920
The shape of the XMX tensor core operation is defined by systolic depth, repeat count, execution size and operations per channel.
2021

@@ -23,43 +24,147 @@ The encoding is characterized by parameters:
2324
- `systolicDepth` For PVC/ATSM, the size is 8.
2425
- `executionSize` For PVC, the size is 16. For ATSM, the size is 8.
2526
- `opsPerChannel` 4 for 8 bit scalar type, 2 for 16 bit scalar type, 1 for 32 bit scalar type.
26-
- `warpsPerCTA`
27-
- `sugGroupSize` valid sub group size is 8/16/32
28-
29-
30-
The layout example repeat_count=8, systolic_depth=8, execution_size=16 and operands_per_chan=2 for warp size 32.
31-
For A operand:
32-
systolic depth = 8
33-
<------------------------------------------------------------------------------------------------->
34-
opsPerChan=2
35-
<--------->
36-
t0 ... t0 t1 ... t1 t2 ... t2 t3 ... t3 t4 ... t4 t5 ... t5 t6 ... t6 t7 ... t7 ^
37-
t8 ... t8 t9 ... t9 t10 ... t10 t11 ... t11 t12 ... t12 t13 ... t13 t14 ... t14 t15 ... t15 |
38-
t16 ... t16 t17 ... t17 t18 ... t18 t19 ... t19 t20 ... t20 t21 ... t21 t22 ... t22 t23 ... t23 |
39-
t24 ... t24 t25 ... t25 t26 ... t26 t27 ... t27 t28 ... t28 t29 ... t29 t30 ... t30 t31 ... t31 | repeat count <= 8
40-
t0 ... t0 t1 ... t1 t2 ... t2 t3 ... t3 t4 ... t4 t5 ... t5 t6 ... t6 t7 ... t7 |
41-
t8 ... t8 t9 ... t9 t10 ... t10 t11 ... t11 t12 ... t12 t13 ... t13 t14 ... t14 t15 ... t15 |
42-
t16 ... t16 t17 ... t17 t18 ... t18 t19 ... t19 t20 ... t20 t21 ... t21 t22 ... t22 t23 ... t23 |
43-
t24 ... t24 t25 ... t25 t26 ... t26 t27 ... t27 t28 ... t28 t29 ... t29 t30 ... t30 t31 ... t31 v
44-
45-
For B operand:
46-
execution size = 16
47-
<------------------------------------------------------------->
48-
t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 ^ ^
49-
. . . . . . . . . . . . . . . . | opsPerChan=2|
50-
t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 v |
51-
t16 t17 t18 t19 t20 t21 t22 t23 t24 t25 t26 t27 t28 t29 t30 t31 |
52-
. . . . . . . . . . . . . . . . |
53-
t16 t17 t18 t19 t20 t21 t22 t23 t24 t25 t26 t27 t28 t29 t30 t31 | systolic depth = 8
54-
t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
55-
. . . . . . . . . . . . . . . . |
56-
t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
57-
t16 t17 t18 t19 t20 t21 t22 t23 t24 t25 t26 t27 t28 t29 t30 t31 |
58-
. . . . . . . . . . . . . . . . |
59-
t16 t17 t18 t19 t20 t21 t22 t23 t24 t25 t26 t27 t28 t29 t30 t31 v
60-
61-
This pattern repeats every warpsPerTile[0] (resp. warpsPerTile[1]) blocks
62-
along the row (resp. col) dimension.
27+
- `warpsPerCTA` indicates the distribution of the warps in the block. The order is [1, 0] for rank 2.
28+
- `repCluster` indicates the cluster size of the repetitions of the DPAS tile.
29+
- `sugGroupSize` Currently only sub group size 16 is supported.
30+
31+
The values of the matrix are distributed across the threads in the subgroup in row-major order.
32+
- If the column size of the matrix is equal to the number of threads in the subgroup, a single value name represents a single row of the matrix.
33+
- If the column size of the matrix is less than the number of threads in the subgroup, a single value name represents multiple rows of the matrix.
34+
- If the column size of the matrix is larger than the number of threads in the subgroup, a single row of the matrix requires multiple value names.
35+
36+
Example 1, the column size of the matrix is 16 and the number of threads in the subgroup is 16.
37+
The DPAS encoding of repeatCount=8, systolicDepth=8, executionSize=16, opsPerChannel=2 and sugGroupSize=16.
38+
39+
The layout for A operand:
40+
K = 16 (K = systolic depth * opsPerChan)
41+
<---------------------------------------------------------------------------->
42+
43+
t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 ^
44+
t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
45+
t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
46+
t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
47+
t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 | M = 8 (repeat count)
48+
t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
49+
t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
50+
t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 v
51+
52+
The layout for B operand:
53+
N = 16 (N = execution size)
54+
<---------------------------------------------------------------------------->
55+
56+
t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 ^
57+
t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
58+
t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
59+
t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
60+
t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
61+
t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
62+
t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
63+
t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 | K = 16 (K = systolic depth * opsPerChan)
64+
t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
65+
t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
66+
t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
67+
t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
68+
t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
69+
t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
70+
t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
71+
t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 v
72+
73+
The layout for C operand and result D:
74+
N = 16 (N = execution size)
75+
<---------------------------------------------------------------------------->
76+
t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 ^
77+
t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
78+
t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
79+
t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
80+
t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 | M = 8 (M=repeat count)
81+
t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
82+
t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
83+
t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 v
84+
85+
Example 2, the column size of the matrix is 8 and the number of threads in the subgroup is 16.
86+
The DPAS encoding of repeatCount=8, systolicDepth=8, executionSize=16, opsPerChannel=1 and sugGroupSize=16.
87+
88+
The layout for A operand:
89+
K = 8 (K = systolic depth * opsPerChan)
90+
<---------------------------------------->
91+
92+
t0 t1 t2 t3 t4 t5 t6 t7 ^
93+
t8 t9 t10 t11 t12 t13 t14 t15 |
94+
t0 t1 t2 t3 t4 t5 t6 t7 |
95+
t8 t9 t10 t11 t12 t13 t14 t15 |
96+
t0 t1 t2 t3 t4 t5 t6 t7 | M = 8 (repeat count)
97+
t8 t9 t10 t11 t12 t13 t14 t15 |
98+
t0 t1 t2 t3 t4 t5 t6 t7 |
99+
t8 t9 t10 t11 t12 t13 t14 t15 v
100+
101+
The layout for the B operand is like the one for opsPerChan=2, but the K size is 8.
102+
The layouts for the C and D operands are the same as those for opsPerChan=2.
103+
104+
Example 3, the column size of the matrix is 32 and the number of threads in the subgroup is 16.
105+
The DPAS encoding of repeatCount=8, systolicDepth=8, executionSize=16, opsPerChannel=4 and sugGroupSize=16.
106+
107+
The layout for A operand:
108+
K = 32 (K = systolic depth * opsPerChan)
109+
<----------------------------------------------------------------------------------------------------------------------------------->
110+
111+
t0 t0 t1 t1 t2 t2 t3 t3 t4 t4 t5 t5 t6 t6 t7 t7 t8 t8 t9 t9 t10 t10 t11 t11 t12 t12 t13 t13 t14 t14 t15 t15 ^
112+
t0 t0 t1 t1 t2 t2 t3 t3 t4 t4 t5 t5 t6 t6 t7 t7 t8 t8 t9 t9 t10 t10 t11 t11 t12 t12 t13 t13 t14 t14 t15 t15 |
113+
t0 t0 t1 t1 t2 t2 t3 t3 t4 t4 t5 t5 t6 t6 t7 t7 t8 t8 t9 t9 t10 t10 t11 t11 t12 t12 t13 t13 t14 t14 t15 t15 |
114+
t0 t0 t1 t1 t2 t2 t3 t3 t4 t4 t5 t5 t6 t6 t7 t7 t8 t8 t9 t9 t10 t10 t11 t11 t12 t12 t13 t13 t14 t14 t15 t15 |
115+
t0 t0 t1 t1 t2 t2 t3 t3 t4 t4 t5 t5 t6 t6 t7 t7 t8 t8 t9 t9 t10 t10 t11 t11 t12 t12 t13 t13 t14 t14 t15 t15 | M = 8 (repeat count)
116+
t0 t0 t1 t1 t2 t2 t3 t3 t4 t4 t5 t5 t6 t6 t7 t7 t8 t8 t9 t9 t10 t10 t11 t11 t12 t12 t13 t13 t14 t14 t15 t15 |
117+
t0 t0 t1 t1 t2 t2 t3 t3 t4 t4 t5 t5 t6 t6 t7 t7 t8 t8 t9 t9 t10 t10 t11 t11 t12 t12 t13 t13 t14 t14 t15 t15 |
118+
t0 t0 t1 t1 t2 t2 t3 t3 t4 t4 t5 t5 t6 t6 t7 t7 t8 t8 t9 t9 t10 t10 t11 t11 t12 t12 t13 t13 t14 t14 t15 t15 v
119+
120+
The layout for the B operand is like the one for opsPerChan=2, but the K size is 32.
121+
The layouts for the C and D operands are the same as those for opsPerChan=2.
122+
123+
The patterns (illustrated above) repeat every warpsPerTile[0] (resp. warpsPerTile[1]) blocks
124+
along the row (resp. col) dimension. The repetitions are grouped into clusters of size repCluster to optimize memory access.
125+
126+
Suppose we have a `tt.dot` operation with block size [64, 128] += [64, 32] * [32, 128] on fp16/bf16.
127+
`warpsPerCTA` is set to [2, 2]. The number of repetitions of the DPAS tile per warp is: A=8, B=8, C, D=16.
128+
The DPAS repetitions are distributed as follows:
129+
130+
warp[:0] warp[:1] warp[:0] warp[:1]
131+
|----^----|----^----|----^----|----^----|
132+
repCluster[1]
133+
<--------->
134+
┌────┬────┬────┬────┬────┬────┬────┬────┐
135+
│R0 │R1 │ │ │R4 │R5 │ │ │
136+
│ │ │ │ │ │ │ │ │
137+
├────┼────┼────┼────┼────┼────┼────┼────┤
138+
│R2 │R3 │ │ │R6 │R7 │ │ │
139+
│ │ │ │ │ │ │ │ │
140+
└────┴────┴────┴────┴────┴────┴────┴────┘
141+
142+
- ^ ┌────┬────┐ ┌────┬────┬────┬────┬────┬────┬────┬────┐
143+
| | │R0 │R2 │ │R0 │R1 │ │ │R4 │R5 │ │ │
144+
| | │ │ │ │ │ │ │ │ │ │ │ │
145+
warp[0:] < repCluster[0] | ├────┼────┤ ├────┼────┼────┼────┼────┼────┼────┼────┤
146+
| | │R1 │R3 │ │R2 │R3 │ │ │R6 │R7 │ │ │
147+
| | │ │ │ │ │ │ │ │ │ │ │ │
148+
- v ├────┼────┤ ├────┼────┼────┼────┼────┼────┼────┼────┤
149+
| │ │ │ │ │ │ │ │ │ │ │ │
150+
| │ │ │ │ │ │ │ │ │ │ │ │
151+
warp[1:] < ├────┼────┤ ├────┼────┼────┼────┼────┼────┼────┼────┤
152+
| │ │ │ │ │ │ │ │ │ │ │ │
153+
| │ │ │ │ │ │ │ │ │ │ │ │
154+
- ├────┼────┤ ├────┼────┼────┼────┼────┼────┼────┼────┤
155+
| │R4 │R6 │ │R8 │R9 │ │ │R12 │R13 │ │ │
156+
| │ │ │ │ │ │ │ │ │ │ │ │
157+
warp[0:] < ├────┼────┤ ├────┼────┼────┼────┼────┼────┼────┼────┤
158+
| │R5 │R7 │ │R10 │R11 │ │ │R14 │R15 │ │ │
159+
| │ │ │ │ │ │ │ │ │ │ │ │
160+
- ├────┼────┤ ├────┼────┼────┼────┼────┼────┼────┼────┤
161+
| │ │ │ │ │ │ │ │ │ │ │ │
162+
| │ │ │ │ │ │ │ │ │ │ │ │
163+
warp[1:] < ├────┼────┤ ├────┼────┼────┼────┼────┼────┼────┼────┤
164+
| │ │ │ │ │ │ │ │ │ │ │ │
165+
| │ │ │ │ │ │ │ │ │ │ │ │
166+
- └────┴────┘ └────┴────┴────┴────┴────┴────┴────┴────┘
167+
63168
}];
64169

65170
let parameters = (
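
The shape arithmetic in the updated description can be checked with a short Python sketch. This is not the backend's implementation: the function names (`dpas_tile_shapes`, `repetitions_per_warp`, `a_operand_thread`) are hypothetical and exist only for illustration. The repetition count assumes that only the M and N dimensions are split across warps, which is inferred from the A=8, B=8, C, D=16 figures in the `tt.dot` example. The thread mapping merely reproduces the three A operand diagrams, assuming a sub group size of 16 and a K that either divides or is a multiple of 16.

```python
def ceil_div(a, b):
    """Integer ceiling division."""
    return -(-a // b)


def dpas_tile_shapes(repeat_count, systolic_depth, execution_size, ops_per_channel):
    """Tile shapes as stated in the description: M = repeat count,
    K = systolic depth * opsPerChan, N = execution size."""
    m = repeat_count
    k = systolic_depth * ops_per_channel
    n = execution_size
    return {"A": (m, k), "B": (k, n), "C": (m, n)}


def repetitions_per_warp(block_m, block_n, block_k, warps_per_cta, shapes):
    """DPAS tile repetitions held by one warp.

    Assumption (hypothetical helper): warps_per_cta[0] splits M,
    warps_per_cta[1] splits N, and K is not split across warps.
    """
    m, k = shapes["A"]
    _, n = shapes["B"]
    rep_m = ceil_div(ceil_div(block_m, warps_per_cta[0]), m)
    rep_n = ceil_div(ceil_div(block_n, warps_per_cta[1]), n)
    rep_k = ceil_div(block_k, k)
    return {"A": rep_m * rep_k, "B": rep_k * rep_n, "C": rep_m * rep_n}


def a_operand_thread(row, col, k, sub_group_size=16):
    """Thread id that owns element (row, col) of the A tile, matching the
    three diagrams only (K = 8, 16, 32 with sub group size 16)."""
    if k <= sub_group_size:
        # Row-major walk over the tile, wrapping every sub_group_size values.
        return (row * k + col) % sub_group_size
    # Wider rows: each thread owns k // sub_group_size consecutive columns.
    return col // (k // sub_group_size)


# fp16/bf16 example from the description: opsPerChannel=2, warpsPerCTA=[2, 2].
shapes = dpas_tile_shapes(8, 8, 16, 2)
reps = repetitions_per_warp(64, 128, 32, (2, 2), shapes)
print(shapes)
print(reps)
```

With the parameters of the `tt.dot` example, the sketch yields the same per-warp repetition counts as the description (A=8, B=8, C, D=16).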
