
Commit bfccedd

Update the documents based on review comments.
1 parent 1a8a0a7 commit bfccedd

6 files changed: 34 additions, 27 deletions

third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td

Lines changed: 22 additions & 15 deletions
@@ -26,15 +26,16 @@ The encoding is characterized by parameters:
 - `opsPerChannel` 4 for 8 bit scalar type, 2 for 16 bit scalar type, 1 for 32 bit scalar type.
 - `warpsPerCTA` indicates the distribution of the warps in the block. The order is [1, 0] for rank 2.
 - `repCluster` indicates the cluster size of the repetitions of the DPAS tile.
-- `sugGroupSize` Currently only sub group size 16 is supported.
+- `threadsPerWarp_` AKA threadsPerWarp. The name conflicts with getThreadsPerWarp in the DistributedLayout interface,
+  so we use the name threadsPerWarp_ here. Currently only 16 is supported.
 
 The values of the matrix are distributed across the threads in the subgroup in row-major order.
 - If the column size of the matrix is equal to the number of threads in the subgroup, a single value name represents a single row of the matrix.
 - If the column size of the matrix is less than the number of threads in the subgroup, a single value name represents multiple rows of the matrix.
 - If the column size of the matrix is larger than the number of threads in the subgroup, a single row of the matrix requires multiple value names.
 
 Example 1, the column size of the matrix is 16 and the number of threads in the subgroup is 16.
-The DPAS encoding of repeatCount=8, systolicDepth=8, executionSize=16, opsPerChannel=2 and sugGroupSize=16.
+The DPAS encoding of repeatCount=8, systolicDepth=8, executionSize=16, opsPerChannel=2 and threadsPerWarp=16.
 
 The layout for A operand:
 K = 16 (K = systolic depth * opsPerChan)
@@ -83,7 +84,7 @@ t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15
 t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 v
 
 Example 2, the column size of the matrix is 8 and the number of threads in the subgroup is 16.
-The DPAS encoding of repeatCount=8, systolicDepth=8, executionSize=16, opsPerChannel=1 and sugGroupSize=16.
+The DPAS encoding of repeatCount=8, systolicDepth=8, executionSize=16, opsPerChannel=1 and threadsPerWarp=16.
 
 The layout for A operand:
 K = 8 (K = systolic depth * opsPerChan)
@@ -102,7 +103,7 @@ The layout for B operand is like the one of opsPerChan=2 but the K size is 8.
 The layouts for C and D operands are the same as the one of opsPerChan=2.
 
 Example 3, the column size of the matrix is 32 and the number of threads in the subgroup is 16.
-The DPAS encoding of repeatCount=8, systolicDepth=8, executionSize=16, opsPerChannel=4 and sugGroupSize=16.
+The DPAS encoding of repeatCount=8, systolicDepth=8, executionSize=16, opsPerChannel=4 and threadsPerWarp=16.
 
 The layout for A operand:
 K = 32 (K = systolic depth * opsPerChan)
@@ -121,15 +122,21 @@ The layout for B operand is like the one of opsPerChan=2 but the K size is 32.
 The layouts for C and D operands are the same as the one of opsPerChan=2.
 
 The patterns (illustrated above) repeat every warpsPerTile[0] (resp. warpsPerTile[1]) blocks
-along the row (resp. col) dimension. And the repetitions are clustered of the size of repCluster to optimize the memory accessing.
+along the row (resp. col) dimension, and the repetitions are clustered in groups of size repCluster to optimize memory access.
 
-Suppose we have a `tt.dot` operation of the block size [64, 128] += [64, 32] * [32, 128] of hf16/bf16.
-The `warpsPerCTA` set to [2, 2]. The number of repetitions of the DPAS tile per warp is: A=8, B=8, C,D=16.
-The DPAS repetitions are distributed as follows:
+Suppose we have a `tt.dot` operation of the block size [64, 128] = [64, 32] * [32, 128] of f16/bf16, and its input tensor layout is defined as follows:
+```
+#dpas = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [2, 2], repCluster = [2, 2]}>
+#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#dpas, kWidth=2}>
+#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#dpas, kWidth=2}>
 
-           warp[:0]   warp[:1]   warp[:0]   warp[:1]
+%d = tt.dot %a, %b, %c : tensor<64x32xf16, #dot_operand_a> * tensor<32x128xf16, #dot_operand_b> -> tensor<64x128xf32, #dpas>
+```
+The semantics of this `tt.dot` include the following GEMM tiling configuration:
+
+           warp[:,0]  warp[:,1]  warp[:,0]  warp[:,1]
 |----^----|----^----|----^----|----^----|
-repCluster[1]
+  repCluster[1]
 <--------->
 ┌────┬────┬────┬────┬────┬────┬────┬────┐
 │R0  │R1  │    │    │R4  │R5  │    │    │
@@ -142,25 +149,25 @@ The DPAS repetitions are distributed as follows:
    -    ^    ┌────┬────┐   ┌────┬────┬────┬────┬────┬────┬────┬────┐
   |    |    │R0  │R2  │   │R0  │R1  │    │    │R4  │R5  │    │    │
   |    |    │    │    │   │    │    │    │    │    │    │    │    │
-warp[0:]  < repCluster[0] | ]────┼────┤ ├────┼────┼────┼────┼────┼────┼────┼────┤
+warp[0,:] < repCluster[0] | ├────┼────┤ ├────┼────┼────┼────┼────┼────┼────┼────┤
   |    |    │R1  │R3  │   │R2  │R3  │    │    │R6  │R7  │    │    │
   |    |    │    │    │   │    │    │    │    │    │    │    │    │
   -    v    ├────┼────┤   ├────┼────┼────┼────┼────┼────┼────┼────┤
   |         │    │    │   │    │    │    │    │    │    │    │    │
   |         │    │    │   │    │    │    │    │    │    │    │    │
-warp[1:]  < ├────┼────┤   ├────┼────┼────┼────┼────┼────┼────┼────┤
+warp[1,:] < ├────┼────┤   ├────┼────┼────┼────┼────┼────┼────┼────┤
   |         │    │    │   │    │    │    │    │    │    │    │    │
   |         │    │    │   │    │    │    │    │    │    │    │    │
   -         ├────┼────┤   ├────┼────┼────┼────┼────┼────┼────┼────┤
   |         │R4  │R6  │   │R8  │R9  │    │    │R12 │R13 │    │    │
   |         │    │    │   │    │    │    │    │    │    │    │    │
-warp[0:]  < ├────┼────┤   ├────┼────┼────┼────┼────┼────┼────┼────┤
+warp[0,:] < ├────┼────┤   ├────┼────┼────┼────┼────┼────┼────┼────┤
   |         │R5  │R7  │   │R10 │R11 │    │    │R14 │R15 │    │    │
   |         │    │    │   │    │    │    │    │    │    │    │    │
   -         ├────┼────┤   ├────┼────┼────┼────┼────┼────┼────┼────┤
   |         │    │    │   │    │    │    │    │    │    │    │    │
   |         │    │    │   │    │    │    │    │    │    │    │    │
-warp[1:]  < ├────┼────┤   ├────┼────┼────┼────┼────┼────┼────┼────┤
+warp[1,:] < ├────┼────┤   ├────┼────┼────┼────┼────┼────┼────┼────┤
   |         │    │    │   │    │    │    │    │    │    │    │    │
   |         │    │    │   │    │    │    │    │    │    │    │    │
   -         └────┴────┘   └────┴────┴────┴────┴────┴────┴────┴────┘
@@ -175,7 +182,7 @@ The DPAS repetitions are distributed as follows:
     "unsigned":$opsPerChannel,
     ArrayRefParameter<"unsigned">:$warpsPerCTA__,
     ArrayRefParameter<"unsigned">:$repCluster,
-    "unsigned":$subGroupSize
+    "unsigned":$threadsPerWarp_
   );
 
   let extraClassDeclaration = extraDistributedDeclaration # [{
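
As a cross-check on the documentation example above: the per-warp repetition counts quoted in the removed prose (A=8, B=8, C,D=16) still follow from the parameters of the new `#dpas` attribute. Below is a minimal standalone C++ sketch of that arithmetic, assuming the tiling relations stated in the doc comment (A is repeatCount x K, B is K x executionSize, K = systolicDepth * opsPerChannel, and the M and N dimensions are split across warpsPerCTA); it is illustrative only, not code from this repository.

```
#include <array>
#include <cstdio>

int main() {
  // Parameters of the #dpas attribute in the example above.
  unsigned repeatCount = 8, systolicDepth = 8, executionSize = 16,
           opsPerChannel = 2;
  std::array<unsigned, 2> warpsPerCTA{2, 2}, repCluster{2, 2};

  // Per-instruction tile shapes, as stated in the doc comment:
  // A is M x K, B is K x N, C/D are M x N, K = systolicDepth * opsPerChan.
  unsigned M = repeatCount;                   // 8
  unsigned K = systolicDepth * opsPerChannel; // 16
  unsigned N = executionSize;                 // 16

  // Block shape of the example: [64, 128] = [64, 32] * [32, 128].
  unsigned blockM = 64, blockN = 128, blockK = 32;

  // Tiles per warp: M and N are split across warpsPerCTA, K is not.
  unsigned repM = blockM / (M * warpsPerCTA[0]); // 4
  unsigned repN = blockN / (N * warpsPerCTA[1]); // 4
  unsigned repK = blockK / K;                    // 2

  std::printf("A=%u B=%u C,D=%u\n", repM * repK, repK * repN, repM * repN);
  // Prints: A=8 B=8 C,D=16

  // The repetitions are grouped into clusters of repCluster = [2, 2];
  // the R0..R15 labels in the diagram enumerate one warp's C tiles.
  std::printf("clusters per warp: %u x %u\n", repM / repCluster[0],
              repN / repCluster[1]);
  return 0;
}
```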

third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp

Lines changed: 7 additions & 7 deletions
@@ -134,7 +134,7 @@ SmallVector<unsigned> DpasEncodingAttr::getShapeC() const {
 SmallVector<unsigned> DpasEncodingAttr::getSizePerThread() const {
   size_t rank = getWarpsPerCTA().size();
   SmallVector<unsigned> res(rank, 1);
-  unsigned threadsPerWarp = getSubGroupSize();
+  unsigned threadsPerWarp = getThreadsPerWarp_();
   auto shapeC = getDPASInstShapeC();
   unsigned elemsNum = product<unsigned>(shapeC);
   unsigned elemsPerThread = elemsNum / threadsPerWarp;
@@ -260,7 +260,7 @@ unsigned DpasEncodingAttr::getTotalElemsPerThreadForOperand(
     ArrayRef<int64_t> shape, mlir::Type eltTy, int kWidth, int opIdx) const {
   auto shapePerCTA = getShapePerCTA(*this, shape);
   auto rep = getDPASRepetitions(shapePerCTA, opIdx);
-  auto threadsPerWar = getSubGroupSize();
+  auto threadsPerWar = getThreadsPerWarp_();
   size_t rank = shape.size();
   if (opIdx == 0) {
     auto shapeA = getShapeA();
@@ -296,7 +296,7 @@ SmallVector<unsigned> DpasEncodingAttr::getThreadsPerWarp() const {
   size_t rank = getWarpsPerCTA().size();
   SmallVector<unsigned> res(rank, 1);
   auto executionSize = getExecutionSize();
-  auto subGroupSize = getSubGroupSize();
+  auto subGroupSize = getThreadsPerWarp_();
   if (subGroupSize < executionSize) {
     llvm::report_fatal_error("DpasEncodingAttr sub-group size could not be "
                              "smaller than the execution size");
@@ -340,7 +340,7 @@ DpasEncodingAttr::getSizePerThreadForOperand(int kWidth, unsigned opIdx) const {
   assert((rank == 2 || rank == 3) && "unexpected rank number for Dpas layout");
   if (opIdx == 0) {
     SmallVector<unsigned> shapeA = getDPASInstShapeA();
-    unsigned subGroupSize = getSubGroupSize();
+    unsigned subGroupSize = getThreadsPerWarp_();
     unsigned opsPerChannel = getOpsPerChannel();
 
     // pack the value to i16 for scalar bit width <=16.
@@ -359,7 +359,7 @@ DpasEncodingAttr::getSizePerThreadForOperand(int kWidth, unsigned opIdx) const {
 
   if (opIdx == 1) {
     auto shapeB = getShapeB();
-    auto subGroupSize = getSubGroupSize();
+    auto subGroupSize = getThreadsPerWarp_();
     auto executionSize = getExecutionSize();
     if (subGroupSize < executionSize) {
       llvm::report_fatal_error("DpasEncodingAttr sub-group size could not "
@@ -394,7 +394,7 @@ SmallVector<unsigned> DpasEncodingAttr::getContigPerThread() {
   assert(rank == 2 || rank == 3);
   SmallVector<unsigned> contigPerThread(rank, 1);
 
-  unsigned threadsPerWarp = getSubGroupSize();
+  unsigned threadsPerWarp = getThreadsPerWarp_();
   auto instShapeC = getDPASInstShapeC();
   // The software vectorization vectorized the value as C array: int a[N] -> int
   // a[N][threadsPerWarp]
@@ -506,7 +506,7 @@ void DpasEncodingAttr::print(AsmPrinter &printer) const {
           << "systolicDepth = " << getSystolicDepth() << ", "
           << "executionSize = " << getExecutionSize() << ", "
          << "opsPerChan = " << getOpsPerChannel() << ", "
-          << "threadsPerWarp = " << getSubGroupSize() << ", "
+          << "threadsPerWarp = " << getThreadsPerWarp_() << ", "
           << "warpsPerCTA = [" << llvm::ArrayRef<unsigned>(warpsPerCTA) << "], "
           << "repCluster = [" << repCluster << "], "
           << "A = [" << rA << "], "

third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 1 addition & 1 deletion
@@ -334,7 +334,7 @@ struct ConvertLayoutOpConversion
     size_t totalElems = elems.size();
     auto numElemsPerOperand =
         product<unsigned>(dpasLayout.getDPASInstShapeC()) /
-        dpasLayout.getSubGroupSize();
+        dpasLayout.getThreadsPerWarp_();
     Type elemTy =
         this->getTypeConverter()->convertType(srcType.getElementType());
     VectorType dotOpTy = vec_ty(elemTy, numElemsPerOperand);

third_party/intel/lib/TritonIntelGPUToLLVM/DotOpToLLVM/DPAS.cpp

Lines changed: 1 addition & 1 deletion
@@ -36,7 +36,7 @@ class DotOpDPASConversionHelper {
     Type i16Ty = type::i16Ty(ctx);
     Type s32Ty = IntegerType::get(ctx, 32, IntegerType::Signed);
 
-    unsigned threadsPerWarp = layout.getSubGroupSize();
+    unsigned threadsPerWarp = layout.getThreadsPerWarp_();
     unsigned opsPerChannel = layout.getOpsPerChannel();
     SmallVector<unsigned> shapeC = layout.getDPASInstShapeC();
     unsigned elemNumC = product<unsigned>(shapeC) / threadsPerWarp;

third_party/intel/lib/TritonIntelGPUToLLVM/Utility.h

Lines changed: 1 addition & 1 deletion
@@ -168,7 +168,7 @@ emitOffsetForDpasLayoutPerCTA(const DpasEncodingAttr &dpasLayout,
       sizePerThreads[rank - 2] / repCluster[rank - 2],
       sizePerThreads[rank - 1] / repCluster[rank - 1]};
 
-  unsigned rowsPerElem = dpasLayout.getSubGroupSize() / instShapeC[1];
+  unsigned rowsPerElem = dpasLayout.getThreadsPerWarp_() / instShapeC[1];
   unsigned colsPerElem = 1;
 
   unsigned repNumber = product<unsigned>(repCluster);
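
The `rowsPerElem` quotient updated here mirrors the distribution rule from the `.td` doc comment: when a tile's column size divides the sub-group size, one value name spans `threadsPerWarp / columns` rows. A small illustrative check, with values taken from the doc's examples rather than from repo output:

```
#include <cstdio>

int main() {
  unsigned threadsPerWarp = 16; // dpasLayout.getThreadsPerWarp_()

  // C tile: the column size equals the sub-group size, so one value name
  // covers exactly one row.
  unsigned colsC = 16;
  std::printf("rowsPerElem(C) = %u\n", threadsPerWarp / colsC); // 1

  // A tile of Example 2 (K = 8): the column size is half the sub-group
  // size, so a single value name covers two rows, as the doc describes.
  unsigned colsA = 8;
  std::printf("rows per value (A) = %u\n", threadsPerWarp / colsA); // 2
  return 0;
}
```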

third_party/intel/lib/TritonIntelGPUTransforms/OptimizeReductionLocality.cpp

Lines changed: 2 additions & 2 deletions
@@ -176,10 +176,10 @@ struct DpasOperandPattern final : OpRewritePattern<ReduceOp> {
     // We want to transpose matrices of N*threads_per_warpxthreads_per_warp
     // shape.
     if ( // X axis condition
-        encoding.getExecutionSize() != encoding.getSubGroupSize() ||
+        encoding.getExecutionSize() != encoding.getThreadsPerWarp_() ||
         // Y axis conditions
         (encoding.getRepeatCount() * encoding.getRepCluster()[0]) %
-                encoding.getSubGroupSize() !=
+                encoding.getThreadsPerWarp_() !=
             0)
       return failure();
 
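Finally, for the documentation example's encoding, both legality conditions that this pattern now queries through `getThreadsPerWarp_()` hold, so the rewrite would apply there. A standalone hedged check, with parameter values copied from the example `#dpas` attribute rather than taken from repo code:

```
#include <cassert>
#include <cstdio>

int main() {
  // Parameters copied from the #dpas attribute in the documentation example.
  unsigned executionSize = 16, threadsPerWarp = 16;
  unsigned repeatCount = 8, repCluster0 = 2;

  // X axis condition: the execution size must equal the sub-group size.
  assert(executionSize == threadsPerWarp);
  // Y axis condition: repeatCount * repCluster[0] must be a multiple of the
  // sub-group size (8 * 2 = 16, so the pattern applies here).
  assert((repeatCount * repCluster0) % threadsPerWarp == 0);

  std::printf("DpasOperandPattern preconditions hold\n");
  return 0;
}
```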