@@ -14,7 +14,8 @@ def DpasEncodingAttr : DistributedEncoding<"DpasEncoding", "intel_dpas_encoding"
1414 let mnemonic = "dpas";
1515
1616 let description = [{
17- An encoding for the tensors distributed across the threads for the C and D operands of XMX tensor core operation.
17+ An encoding for the tensors distributed across the threads for the C and D operands of XMX tensor core operation
18+ and its corresponding A and B operands layout with the DPAS encoding as parent.
1819The XMX tensor core operation is defined for matrix matmul as: D=A*B+C
1920The shape of the of XMX tensor core operation is defined by systolic depth, repeat count, execution size and operations per channel.
2021
@@ -23,43 +24,147 @@ The encoding is characterized by parameters:
2324 - `systolicDepth` For PVC/ATSM, the size is 8.
2425 - `executionSize` For PVC, the size is 16. For ATSM, the size is 8.
2526 - `opsPerChannel` 4 for 8 bit scalar type, 2 for 16 bit scalar type, 1 for 32 bit scalar type.
26- - `warpsPerCTA`
27- - `sugGroupSize` valid sub group size is 8/16/32
28-
29-
30- The layout example repeat_count=8, systolic_depth=8, execution_size=16 and operands_per_chan=2 for warp size 32.
31- For A operand:
32- systolic depth = 8
33- <------------------------------------------------------------------------------------------------->
34- opsPerChan=2
35- <--------->
36- t0 ... t0 t1 ... t1 t2 ... t2 t3 ... t3 t4 ... t4 t5 ... t5 t6 ... t6 t7 ... t7 ^
37- t8 ... t8 t9 ... t9 t10 ... t10 t11 ... t11 t12 ... t12 t13 ... t13 t14 ... t14 t15 ... t15 |
38- t16 ... t16 t17 ... t17 t18 ... t18 t19 ... t19 t20 ... t20 t21 ... t21 t22 ... t22 t23 ... t23 |
39- t24 ... t24 t25 ... t25 t26 ... t26 t27 ... t27 t28 ... t28 t29 ... t29 t30 ... t30 t31 ... t31 | repeat count <= 8
40- t0 ... t0 t1 ... t1 t2 ... t2 t3 ... t3 t4 ... t4 t5 ... t5 t6 ... t6 t7 ... t7 |
41- t8 ... t8 t9 ... t9 t10 ... t10 t11 ... t11 t12 ... t12 t13 ... t13 t14 ... t14 t15 ... t15 |
42- t16 ... t16 t17 ... t17 t18 ... t18 t19 ... t19 t20 ... t20 t21 ... t21 t22 ... t22 t23 ... t23 |
43- t24 ... t24 t25 ... t25 t26 ... t26 t27 ... t27 t28 ... t28 t29 ... t29 t30 ... t30 t31 ... t31 v
44-
45- For B operand:
46- execution size = 16
47- <------------------------------------------------------------->
48- t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 ^ ^
49- . . . . . . . . . . . . . . . . | opsPerChan=2|
50- t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 v |
51- t16 t17 t18 t19 t20 t21 t22 t23 t24 t25 t26 t27 t28 t29 t30 t31 |
52- . . . . . . . . . . . . . . . . |
53- t16 t17 t18 t19 t20 t21 t22 t23 t24 t25 t26 t27 t28 t29 t30 t31 | systolic depth = 8
54- t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
55- . . . . . . . . . . . . . . . . |
56- t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
57- t16 t17 t18 t19 t20 t21 t22 t23 t24 t25 t26 t27 t28 t29 t30 t31 |
58- . . . . . . . . . . . . . . . . |
59- t16 t17 t18 t19 t20 t21 t22 t23 t24 t25 t26 t27 t28 t29 t30 t31 v
60-
61- This pattern repeats every warpsPerTile[0] (resp. warpsPerTile[1]) blocks
62- along the row (resp. col) dimension.
27+ - `warpsPerCTA` indicates the distribution of the warps in the block. The order is [1, 0] for rank 2.
28+ - `repCluster` indicates the cluster size of the repetitions of the DPAS tile.
29+ - `sugGroupSize` Currently only sub group size 16 is supported.
30+
31+ The values of the matrix is distributed across the threads in the subgroup as row-major order.
32+ - If the column size of the matrix is equal to the number of threads in the subgroup, a single value name represents a single rows of the matrix.
33+ - If the column size of the matrix is less than the number of threads in the subgroup, a single value name represents multiple rows of the matrix.
34+ - If the column size of the matrix is larger than the number of the threads in the subgroup, a single row of the matrix requires multiple value name.
35+
36+ Example 1, the column size of the matrix is 16 and the number of threads in the subgroup is 16.
37+ The DPAS encoding of repeatCount=8, systolicDepth=8, executionSize=16, opsPerChannel=2 and sugGroupSize=16.
38+
39+ The layout for A operand:
40+ K = 16 (K = systolic depth * opsPerChan)
41+ <---------------------------------------------------------------------------->
42+
43+ t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 ^
44+ t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
45+ t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
46+ t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
47+ t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 | M = 8 (repeat count)
48+ t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
49+ t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
50+ t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 v
51+
52+ The layout for B operand:
53+ N = 16 (N = execution size)
54+ <---------------------------------------------------------------------------->
55+
56+ t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 ^
57+ t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
58+ t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
59+ t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
60+ t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
61+ t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
62+ t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
63+ t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 | K = 16 (K = systolic depth * opsPerChan)
64+ t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
65+ t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
66+ t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
67+ t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
68+ t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
69+ t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
70+ t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
71+ t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 v
72+
73+ The layout for C operand and result D:
74+ N = 16 (N = execution size)
75+ <---------------------------------------------------------------------------->
76+ t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 ^
77+ t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
78+ t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
79+ t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
80+ t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 | M = 8 (M=repeat count)
81+ t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
82+ t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
83+ t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 v
84+
85+ Example 2, the column size of the matrix is 8 and the number of threads in the subgroup is 16.
86+ The DPAS encoding of repeatCount=8, systolicDepth=8, executionSize=16, opsPerChannel=1 and sugGroupSize=16.
87+
88+ The layout for A operand:
89+ K = 8 (K = systolic depth * opsPerChan)
90+ <---------------------------------------->
91+
92+ t0 t1 t2 t3 t4 t5 t6 t7 ^
93+ t8 t9 t10 t11 t12 t13 t14 t15 |
94+ t0 t1 t2 t3 t4 t5 t6 t7 |
95+ t8 t9 t10 t11 t12 t13 t14 t15 |
96+ t0 t1 t2 t3 t4 t5 t6 t7 | M = 8 (repeat count)
97+ t8 t9 t10 t11 t12 t13 t14 t15 |
98+ t0 t1 t2 t3 t4 t5 t6 t7 |
99+ t8 t9 t10 t11 t12 t13 t14 t15 v
100+
101+ The layouts for B operand is like the one of opsPerChan=2 but the K size is 8.
102+ The layouts for C and D operands are same as the one of opsPerChan=2.
103+
104+ Example 3, the column size of the matrix is 32 and the number of threads in the subgroup is 16.
105+ The DPAS encoding of repeatCount=8, systolicDepth=8, executionSize=16, opsPerChannel=4 and sugGroupSize=16.
106+
107+ The layout for A operand:
108+ K = 32 (K = systolic depth * opsPerChan)
109+ <----------------------------------------------------------------------------------------------------------------------------------->
110+
111+ t0 t0 t1 t1 t2 t2 t3 t3 t4 t4 t5 t5 t6 t6 t7 t7 t8 t8 t9 t9 t10 t10 t11 t11 t12 t12 t13 t13 t14 t14 t15 t15 ^
112+ t0 t0 t1 t1 t2 t2 t3 t3 t4 t4 t5 t5 t6 t6 t7 t7 t8 t8 t9 t9 t10 t10 t11 t11 t12 t12 t13 t13 t14 t14 t15 t15 |
113+ t0 t0 t1 t1 t2 t2 t3 t3 t4 t4 t5 t5 t6 t6 t7 t7 t8 t8 t9 t9 t10 t10 t11 t11 t12 t12 t13 t13 t14 t14 t15 t15 |
114+ t0 t0 t1 t1 t2 t2 t3 t3 t4 t4 t5 t5 t6 t6 t7 t7 t8 t8 t9 t9 t10 t10 t11 t11 t12 t12 t13 t13 t14 t14 t15 t15 |
115+ t0 t0 t1 t1 t2 t2 t3 t3 t4 t4 t5 t5 t6 t6 t7 t7 t8 t8 t9 t9 t10 t10 t11 t11 t12 t12 t13 t13 t14 t14 t15 t15 | M = 8 (repeat count)
116+ t0 t0 t1 t1 t2 t2 t3 t3 t4 t4 t5 t5 t6 t6 t7 t7 t8 t8 t9 t9 t10 t10 t11 t11 t12 t12 t13 t13 t14 t14 t15 t15 |
117+ t0 t0 t1 t1 t2 t2 t3 t3 t4 t4 t5 t5 t6 t6 t7 t7 t8 t8 t9 t9 t10 t10 t11 t11 t12 t12 t13 t13 t14 t14 t15 t15 |
118+ t0 t0 t1 t1 t2 t2 t3 t3 t4 t4 t5 t5 t6 t6 t7 t7 t8 t8 t9 t9 t10 t10 t11 t11 t12 t12 t13 t13 t14 t14 t15 t15 v
119+
120+ The layouts for B operand is like the one of opsPerChan=2 but the K size is 32.
121+ The layouts for C and D operands are same as the one of opsPerChan=2.
122+
123+ The patterns (illustrated above) repeats every warpsPerTile[0] (resp. warpsPerTile[1]) blocks
124+ along the row (resp. col) dimension. And the repetitions are clustered of the size of repCluster to optimize the memory accessing.
125+
126+ Suppose we have a `tt.dot` operation of the block size [64, 128] += [64, 32] * [32, 128] of hf16/bf16.
127+ The `warpsPerCTA` set to [2, 2]. The number of repetitions of the DPAS tile per warp is: A=8, B=8, C,D=16.
128+ The DPAS repetitions are distributed as follows:
129+
130+ warp[:0] warp[:1] warp[:0] warp[:1]
131+ |----^----|----^----|----^----|----^----|
132+ repCluster[1]
133+ <--------->
134+ ┌────┬────┬────┬────┬────┬────┬────┬────┐
135+ │R0 │R1 │ │ │R4 │R5 │ │ │
136+ │ │ │ │ │ │ │ │ │
137+ ├────┼────┼────┼────┼────┼────┼────┼────┤
138+ │R2 │R3 │ │ │R6 │R7 │ │ │
139+ │ │ │ │ │ │ │ │ │
140+ └────┴────┴────┴────┴────┴────┴────┴────┘
141+
142+ - ^ ┌────┬────┐ ┌────┬────┬────┬────┬────┬────┬────┬────┐
143+ | | │R0 │R2 │ │R0 │R1 │ │ │R4 │R5 │ │ │
144+ | | │ │ │ │ │ │ │ │ │ │ │ │
145+ warp[0:] < repCluster[0] | ]────┼────┤ ├────┼────┼────┼────┼────┼────┼────┼────┤
146+ | | │R1 │R3 │ │R2 │R3 │ │ │R6 │R7 │ │ │
147+ | | │ │ │ │ │ │ │ │ │ │ │ │
148+ - v ├────┼────┤ ├────┼────┼────┼────┼────┼────┼────┼────┤
149+ | │ │ │ │ │ │ │ │ │ │ │ │
150+ | │ │ │ │ │ │ │ │ │ │ │ │
151+ warp[1:] < ├────┼────┤ ├────┼────┼────┼────┼────┼────┼────┼────┤
152+ | │ │ │ │ │ │ │ │ │ │ │ │
153+ | │ │ │ │ │ │ │ │ │ │ │ │
154+ - ├────┼────┤ ├────┼────┼────┼────┼────┼────┼────┼────┤
155+ | │R4 │R6 │ │R8 │R9 │ │ │R12 │R13 │ │ │
156+ | │ │ │ │ │ │ │ │ │ │ │ │
157+ warp[0:] < ├────┼────┤ ├────┼────┼────┼────┼────┼────┼────┼────┤
158+ | │R5 │R7 │ │R10 │R11 │ │ │R14 │R15 │ │ │
159+ | │ │ │ │ │ │ │ │ │ │ │ │
160+ - ├────┼────┤ ├────┼────┼────┼────┼────┼────┼────┼────┤
161+ | │ │ │ │ │ │ │ │ │ │ │ │
162+ | │ │ │ │ │ │ │ │ │ │ │ │
163+ warp[1:] < ├────┼────┤ ├────┼────┼────┼────┼────┼────┼────┼────┤
164+ | │ │ │ │ │ │ │ │ │ │ │ │
165+ | │ │ │ │ │ │ │ │ │ │ │ │
166+ - └────┴────┘ └────┴────┴────┴────┴────┴────┴────┴────┘
167+
63168}];
64169
65170 let parameters = (
0 commit comments