
Commit e7c373e
Update the documents based on review comments.
1 parent eda4c40 commit e7c373e

File tree: 6 files changed, +75 additions, -63 deletions


third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td

Lines changed: 60 additions & 51 deletions
@@ -26,15 +26,16 @@ The encoding is characterized by parameters:
 - `opsPerChannel` 4 for 8 bit scalar type, 2 for 16 bit scalar type, 1 for 32 bit scalar type.
 - `warpsPerCTA` indicates the distribution of the warps in the block. The order is [1, 0] for rank 2.
 - `repCluster` indicates the cluster size of the repetitions of the DPAS tile.
-- `sugGroupSize` Currently only sub group size 16 is supported.
+- `threadsPerWarp_` AKA threadsPerWarp; the name threadsPerWarp_ avoids conflicting
+with `getThreadsPerWarp` in the DistributedLayout interface. Currently only 16 is supported.
 
 The values of the matrix is distributed across the threads in the subgroup as row-major order.
-- If the column size of the matrix is equal to the number of threads in the subgroup, a single value name represents a single rows of the matrix.
-- If the column size of the matrix is less than the number of threads in the subgroup, a single value name represents multiple rows of the matrix.
-- If the column size of the matrix is larger than the number of the threads in the subgroup, a single row of the matrix requires multiple value name.
+- If the column size of the matrix is equal to the number of threads in the subgroup, one scalar represents one row of the matrix in register.
+- If the column size of the matrix is less than the number of threads in the subgroup, one scalar represents multiple rows of the matrix in register.
+- If the column size of the matrix is larger than the number of threads in the subgroup, one scalar represents a partial row of the matrix in register.
 
 Example 1, the column size of the matrix is 16 and the number of threads in the subgroup is 16.
-The DPAS encoding of repeatCount=8, systolicDepth=8, executionSize=16, opsPerChannel=2 and sugGroupSize=16.
+The DPAS encoding of repeatCount=8, systolicDepth=8, executionSize=16, opsPerChannel=2 and threadsPerWarp=16.
 
 The layout for A operand:
 K = 16 (K = systolic depth * opsPerChan)
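An aside, not part of the commit: the K sizes quoted in the three examples of this documentation follow directly from the encoding parameters, since K = systolicDepth * opsPerChannel. A minimal standalone C++ sketch of that arithmetic, assuming the fixed systolic depth of 8 used throughout:

```
// K of the A/B operands per DPAS instruction is systolicDepth * opsPerChannel.
// opsPerChannel is 4 for 8-bit, 2 for 16-bit and 1 for 32-bit scalar types,
// giving the K = 8, 16 and 32 of Examples 2, 1 and 3 respectively.
#include <cstdio>

int main() {
  const unsigned systolicDepth = 8;
  for (unsigned opsPerChannel : {1u, 2u, 4u})
    std::printf("opsPerChannel=%u -> K=%u\n", opsPerChannel,
                systolicDepth * opsPerChannel);
}
```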
@@ -83,7 +84,7 @@ t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15
 t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 v
 
 Example 2, the column size of the matrix is 8 and the number of threads in the subgroup is 16.
-The DPAS encoding of repeatCount=8, systolicDepth=8, executionSize=16, opsPerChannel=1 and sugGroupSize=16.
+The DPAS encoding of repeatCount=8, systolicDepth=8, executionSize=16, opsPerChannel=1 and threadsPerWarp=16.
 
 The layout for A operand:
 K = 8 (K = systolic depth * opsPerChan)
@@ -102,7 +103,7 @@ The layouts for B operand is like the one of opsPerChan=2 but the K size is 8.
 The layouts for C and D operands are same as the one of opsPerChan=2.
 
 Example 3, the column size of the matrix is 32 and the number of threads in the subgroup is 16.
-The DPAS encoding of repeatCount=8, systolicDepth=8, executionSize=16, opsPerChannel=4 and sugGroupSize=16.
+The DPAS encoding of repeatCount=8, systolicDepth=8, executionSize=16, opsPerChannel=4 and threadsPerWarp=16.
 
 The layout for A operand:
 K = 32 (K = systolic depth * opsPerChan)
@@ -121,49 +122,57 @@ The layouts for B operand is like the one of opsPerChan=2 but the K size is 32.
 The layouts for C and D operands are same as the one of opsPerChan=2.
 
 The patterns (illustrated above) repeats every warpsPerTile[0] (resp. warpsPerTile[1]) blocks
-along the row (resp. col) dimension. And the repetitions are clustered of the size of repCluster to optimize the memory accessing.
-
-Suppose we have a `tt.dot` operation of the block size [64, 128] += [64, 32] * [32, 128] of hf16/bf16.
-The `warpsPerCTA` set to [2, 2]. The number of repetitions of the DPAS tile per warp is: A=8, B=8, C,D=16.
-The DPAS repetitions are distributed as follows:
-
-  warp[:0]  warp[:1]  warp[:0]  warp[:1]
- |----^----|----^----|----^----|----^----|
-  repCluster[1]
- <--------->
- ┌────┬────┬────┬────┬────┬────┬────┬────┐
- │R0  │R1  │    │    │R4  │R5  │    │    │
- │    │    │    │    │    │    │    │    │
- ├────┼────┼────┼────┼────┼────┼────┼────┤
- │R2  │R3  │    │    │R6  │R7  │    │    │
- │    │    │    │    │    │    │    │    │
- └────┴────┴────┴────┴────┴────┴────┴────┘
-
--                         ^   ┌────┬────┐   ┌────┬────┬────┬────┬────┬────┬────┬────┐
-|                         |   │R0  │R2  │   │R0  │R1  │    │    │R4  │R5  │    │    │
-|                         |   │    │    │   │    │    │    │    │    │    │    │    │
-warp[0:] <  repCluster[0] |   ]────┼────┤   ├────┼────┼────┼────┼────┼────┼────┼────┤
-|                         |   │R1  │R3  │   │R2  │R3  │    │    │R6  │R7  │    │    │
-|                         |   │    │    │   │    │    │    │    │    │    │    │    │
--                         v   ├────┼────┤   ├────┼────┼────┼────┼────┼────┼────┼────┤
-|                             │    │    │   │    │    │    │    │    │    │    │    │
-|                             │    │    │   │    │    │    │    │    │    │    │    │
-warp[1:] <                    ├────┼────┤   ├────┼────┼────┼────┼────┼────┼────┼────┤
-|                             │    │    │   │    │    │    │    │    │    │    │    │
-|                             │    │    │   │    │    │    │    │    │    │    │    │
--                             ├────┼────┤   ├────┼────┼────┼────┼────┼────┼────┼────┤
-|                             │R4  │R6  │   │R8  │R9  │    │    │R12 │R13 │    │    │
-|                             │    │    │   │    │    │    │    │    │    │    │    │
-warp[0:] <                    ├────┼────┤   ├────┼────┼────┼────┼────┼────┼────┼────┤
-|                             │R5  │R7  │   │R10 │R11 │    │    │R14 │R15 │    │    │
-|                             │    │    │   │    │    │    │    │    │    │    │    │
--                             ├────┼────┤   ├────┼────┼────┼────┼────┼────┼────┼────┤
-|                             │    │    │   │    │    │    │    │    │    │    │    │
-|                             │    │    │   │    │    │    │    │    │    │    │    │
-warp[1:] <                    ├────┼────┤   ├────┼────┼────┼────┼────┼────┼────┼────┤
-|                             │    │    │   │    │    │    │    │    │    │    │    │
-|                             │    │    │   │    │    │    │    │    │    │    │    │
--                             └────┴────┘   └────┴────┴────┴────┴────┴────┴────┴────┘
+along the row (resp. col) dimension. And the repetitions are clustered of the size of repCluster to optimize the memory accessing.
+
+Suppose we have a `tt.dot` operation of the block size [64, 128] = [64, 32] * [32, 128] of f16/bf16, and its input tensor layouts are defined as follows:
+```
+#dpas = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [2, 2], repCluster = [2, 2]}>
+#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#dpas, kWidth=2}>
+#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#dpas, kWidth=2}>
+
+%d = tt.dot %a, %b, %c : tensor<64x32xf16, #dot_operand_a> * tensor<32x128xf16, #dot_operand_b> -> tensor<64x128xf32, #dpas>
+```
+The semantics of this `tt.dot` include the GEMM tiling configuration:
+
+                           warp[:0]  warp[:1]  warp[:0]  warp[:1]
+                          |----^----|----^----|----^----|----^----|
+                            repCluster[1]
+                          <--------->
+                          ┌────┬────┬────┬────┬────┬────┬────┬────┐
+                          │W0R0│W0R1│W1R0│W1R1│W0R4│W0R5│W1R4│W1R5│
+                          │W2R0│W2R1│W3R0│W3R1│W2R4│W2R5│W3R4│W3R5│
+warpsPerCTA = [[W0, W1],  ├────┼────┼────┼────┼────┼────┼────┼────┤
+               [W2, W3]]  │W0R2│W0R3│W1R2│W1R3│W0R6│W0R7│W1R6│W1R7│
+                          │W2R2│W2R3│W3R2│W3R3│W2R6│W2R7│W3R6│W3R7│
+                          └────┴────┴────┴────┴────┴────┴────┴────┘
+
+
+-                         ^   ┌────┬────┐   ┌────┬────┬────┬────┬────┬────┬────┬────┐
+|                         |   │W0R0│W0R2│   │W0R0│W0R1│W1R0│W1R1│W0R4│W0R5│W1R4│W1R5│
+|                         |   │W1R0│W1R2│   │    │    │    │    │    │    │    │    │
+warp[0:] <  repCluster[0] |   ├────┼────┤   ├────┼────┼────┼────┼────┼────┼────┼────┤
+|                         |   │W0R1│W0R3│   │W0R2│W0R3│W1R2│W1R3│W0R6│W0R7│W1R6│W1R7│
+|                         |   │W1R1│W1R3│   │    │    │    │    │    │    │    │    │
+-                         v   ├────┼────┤   ├────┼────┼────┼────┼────┼────┼────┼────┤
+|                             │W2R0│W2R2│   │W2R0│W2R1│W3R0│W3R1│W2R4│W2R5│W3R4│W3R5│
+|                             │W3R0│W3R2│   │    │    │    │    │    │    │    │    │
+warp[1:] <                    ├────┼────┤   ├────┼────┼────┼────┼────┼────┼────┼────┤
+|                             │W2R1│W2R3│   │W2R2│W2R3│W3R2│W3R3│W2R6│W2R7│W3R6│W3R7│
+|                             │W3R1│W3R3│   │    │    │    │    │    │    │    │    │
+-                             ├────┼────┤   ├────┼────┼────┼────┼────┼────┼────┼────┤
+|                             │W0R4│W0R6│   │W0R8│W0R9│W1R8│W1R9│W0  │W0  │W1  │W1  │
+|                             │W1R4│W1R6│   │    │    │    │    │R12 │R13 │R12 │R13 │
+warp[0:] <                    ├────┼────┤   ├────┼────┼────┼────┼────┼────┼────┼────┤
+|                             │W0R5│W0R7│   │W0  │W0  │W1  │W1  │W0  │W0  │W1  │W1  │
+|                             │W1R5│W1R7│   │R10 │R11 │R10 │R11 │R14 │R15 │R14 │R15 │
+-                             ├────┼────┤   ├────┼────┼────┼────┼────┼────┼────┼────┤
+|                             │W2R4│W2R6│   │W2R8│W2R9│W3R8│W3R9│W2  │W2  │W3  │W3  │
+|                             │W3R4│W3R6│   │    │    │    │    │R12 │R13 │R12 │R13 │
+warp[1:] <                    ├────┼────┤   ├────┼────┼────┼────┼────┼────┼────┼────┤
+|                             │W2R5│W2R7│   │W2  │W2  │W3  │W3  │W2  │W2  │W3  │W3  │
+|                             │W3R5│W3R7│   │R10 │R11 │R10 │R11 │R14 │R15 │R14 │R15 │
+-                             └────┴────┘   └────┴────┴────┴────┴────┴────┴────┴────┘
+
 
 }];
 
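An aside, not part of the commit: the prose removed above quoted per-warp repetition counts of A=8, B=8, C,D=16 for this GEMM shape. A minimal standalone C++ sketch of that arithmetic, assuming the per-instruction tile shapes implied by the examples (A rows = repeatCount, B/C columns = executionSize, K = systolicDepth * opsPerChan):

```
// Repetitions of the DPAS tile per warp for
// [64, 128] = [64, 32] * [32, 128] with warpsPerCTA = [2, 2], opsPerChan = 2.
#include <cstdio>

int main() {
  const unsigned M = 64, N = 128, K = 32;
  const unsigned warpsPerCTA[2] = {2, 2};
  const unsigned repeatCount = 8;    // rows of one A (and C) tile
  const unsigned executionSize = 16; // cols of one B (and C) tile
  const unsigned instK = 8 * 2;      // systolicDepth * opsPerChan = 16

  const unsigned repM = M / warpsPerCTA[0] / repeatCount;   // 4
  const unsigned repN = N / warpsPerCTA[1] / executionSize; // 4
  const unsigned repK = K / instK;                          // 2
  std::printf("A=%u B=%u C,D=%u\n", repM * repK, repK * repN, repM * repN);
}
```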
@@ -175,7 +184,7 @@ The DPAS repetitions are distributed as follows:
     "unsigned":$opsPerChannel,
     ArrayRefParameter<"unsigned">:$warpsPerCTA__,
     ArrayRefParameter<"unsigned">:$repCluster,
-    "unsigned":$subGroupSize
+    "unsigned":$threadsPerWarp_
   );
 
   let extraClassDeclaration = extraDistributedDeclaration # [{

third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp

Lines changed: 7 additions & 7 deletions
@@ -134,7 +134,7 @@ SmallVector<unsigned> DpasEncodingAttr::getShapeC() const {
 SmallVector<unsigned> DpasEncodingAttr::getSizePerThread() const {
   size_t rank = getWarpsPerCTA().size();
   SmallVector<unsigned> res(rank, 1);
-  unsigned threadsPerWarp = getSubGroupSize();
+  unsigned threadsPerWarp = getThreadsPerWarp_();
   auto shapeC = getDPASInstShapeC();
   unsigned elemsNum = product<unsigned>(shapeC);
   unsigned elemsPerThread = elemsNum / threadsPerWarp;
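An aside, not part of the commit: for the DPAS C tile used throughout the docs (instShapeC = {8, 16}, i.e. repeatCount x executionSize) and 16 threads per warp, the division performed here leaves each thread holding 8 accumulator elements. A minimal standalone C++ sketch of that computation:

```
// elemsPerThread for one DPAS C tile: product(instShapeC) / threadsPerWarp.
#include <cstdio>
#include <functional>
#include <numeric>
#include <vector>

int main() {
  const std::vector<unsigned> instShapeC = {8, 16}; // rows x cols of one C tile
  const unsigned threadsPerWarp = 16;               // sub-group size
  const unsigned elemsNum =
      std::accumulate(instShapeC.begin(), instShapeC.end(), 1u,
                      std::multiplies<unsigned>());
  std::printf("elemsPerThread = %u\n", elemsNum / threadsPerWarp); // prints 8
}
```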
@@ -260,7 +260,7 @@ unsigned DpasEncodingAttr::getTotalElemsPerThreadForOperand(
     ArrayRef<int64_t> shape, mlir::Type eltTy, int kWidth, int opIdx) const {
   auto shapePerCTA = getShapePerCTA(*this, shape);
   auto rep = getDPASRepetitions(shapePerCTA, opIdx);
-  auto threadsPerWar = getSubGroupSize();
+  auto threadsPerWar = getThreadsPerWarp_();
   size_t rank = shape.size();
   if (opIdx == 0) {
     auto shapeA = getShapeA();
@@ -296,7 +296,7 @@ SmallVector<unsigned> DpasEncodingAttr::getThreadsPerWarp() const {
   size_t rank = getWarpsPerCTA().size();
   SmallVector<unsigned> res(rank, 1);
   auto executionSize = getExecutionSize();
-  auto subGroupSize = getSubGroupSize();
+  auto subGroupSize = getThreadsPerWarp_();
   if (subGroupSize < executionSize) {
     llvm::report_fatal_error("DpasEncodingAttr sub-group size could not be "
                              "smaller than the execution size");
@@ -340,7 +340,7 @@ DpasEncodingAttr::getSizePerThreadForOperand(int kWidth, unsigned opIdx) const {
   assert((rank == 2 || rank == 3) && "unexpected rank number for Dpas layout");
   if (opIdx == 0) {
     SmallVector<unsigned> shapeA = getDPASInstShapeA();
-    unsigned subGroupSize = getSubGroupSize();
+    unsigned subGroupSize = getThreadsPerWarp_();
     unsigned opsPerChannel = getOpsPerChannel();
 
     // pack the value to i16 for scalar bit width <=16.
359359

360360
if (opIdx == 1) {
361361
auto shapeB = getShapeB();
362-
auto subGroupSize = getSubGroupSize();
362+
auto subGroupSize = getThreadsPerWarp_();
363363
auto executionSize = getExecutionSize();
364364
if (subGroupSize < executionSize) {
365365
llvm::report_fatal_error("DpasEncodingAttr sub-group size could not "
@@ -394,7 +394,7 @@ SmallVector<unsigned> DpasEncodingAttr::getContigPerThread() {
394394
assert(rank == 2 || rank == 3);
395395
SmallVector<unsigned> contigPerThread(rank, 1);
396396

397-
unsigned threadsPerWarp = getSubGroupSize();
397+
unsigned threadsPerWarp = getThreadsPerWarp_();
398398
auto instShapeC = getDPASInstShapeC();
399399
// The software vectorization vectorized the value as C array: int a[N] -> int
400400
// a[N][threadsPerWarp]
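An aside, not part of the commit: the code comment above describes the software vectorization as viewing the per-warp values as int a[N][threadsPerWarp], so the consecutive values of one lane sit threadsPerWarp apart in the flat buffer. A minimal standalone C++ illustration (the lane id 3 is an arbitrary example):

```
// Lane t owns a[0][t], a[1][t], ..., i.e. flat indices t, t + threadsPerWarp, ...
#include <cstdio>

int main() {
  const unsigned N = 4, threadsPerWarp = 16, lane = 3;
  for (unsigned i = 0; i < N; ++i)
    std::printf("value %u of lane %u -> flat index %u\n", i, lane,
                i * threadsPerWarp + lane);
}
```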
@@ -506,7 +506,7 @@ void DpasEncodingAttr::print(AsmPrinter &printer) const {
          << "systolicDepth = " << getSystolicDepth() << ", "
          << "executionSize = " << getExecutionSize() << ", "
          << "opsPerChan = " << getOpsPerChannel() << ", "
-         << "threadsPerWarp = " << getSubGroupSize() << ", "
+         << "threadsPerWarp = " << getThreadsPerWarp_() << ", "
          << "warpsPerCTA = [" << llvm::ArrayRef<unsigned>(warpsPerCTA) << "], "
          << "repCluster = [" << repCluster << "], "
          << "A = [" << rA << "], "

third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 1 addition & 1 deletion
@@ -334,7 +334,7 @@ struct ConvertLayoutOpConversion
     size_t totalElems = elems.size();
     auto numElemsPerOperand =
         product<unsigned>(dpasLayout.getDPASInstShapeC()) /
-        dpasLayout.getSubGroupSize();
+        product<unsigned>(dpasLayout.getThreadsPerWarp());
     Type elemTy =
         this->getTypeConverter()->convertType(srcType.getElementType());
     VectorType dotOpTy = vec_ty(elemTy, numElemsPerOperand);
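An aside, not part of the commit: this substitution works because `getThreadsPerWarp()` reports the threads of the sub-group split per tensor dimension, so its product recovers the scalar value the deleted `getSubGroupSize()` returned. A minimal standalone C++ illustration (the [1, 16] split is an assumed rank-2 example, not taken from the diff):

```
// product over the per-dimension thread counts == scalar sub-group size.
#include <cstdio>
#include <functional>
#include <numeric>
#include <vector>

int main() {
  const std::vector<unsigned> threadsPerWarp = {1, 16}; // hypothetical split
  const unsigned subGroupSize =
      std::accumulate(threadsPerWarp.begin(), threadsPerWarp.end(), 1u,
                      std::multiplies<unsigned>());
  std::printf("sub-group size = %u\n", subGroupSize); // prints 16
}
```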

third_party/intel/lib/TritonIntelGPUToLLVM/DotOpToLLVM/DPAS.cpp

Lines changed: 1 addition & 1 deletion
@@ -36,7 +36,7 @@ class DotOpDPASConversionHelper {
     Type i16Ty = type::i16Ty(ctx);
     Type s32Ty = IntegerType::get(ctx, 32, IntegerType::Signed);
 
-    unsigned threadsPerWarp = layout.getSubGroupSize();
+    unsigned threadsPerWarp = product<unsigned>(layout.getThreadsPerWarp());
     unsigned opsPerChannel = layout.getOpsPerChannel();
     SmallVector<unsigned> shapeC = layout.getDPASInstShapeC();
     unsigned elemNumC = product<unsigned>(shapeC) / threadsPerWarp;

third_party/intel/lib/TritonIntelGPUToLLVM/Utility.h

Lines changed: 2 additions & 1 deletion
@@ -168,7 +168,8 @@ emitOffsetForDpasLayoutPerCTA(const DpasEncodingAttr &dpasLayout,
       sizePerThreads[rank - 2] / repCluster[rank - 2],
       sizePerThreads[rank - 1] / repCluster[rank - 1]};
 
-  unsigned rowsPerElem = dpasLayout.getSubGroupSize() / instShapeC[1];
+  unsigned rowsPerElem =
+      product<unsigned>(dpasLayout.getThreadsPerWarp()) / instShapeC[1];
   unsigned colsPerElem = 1;
 
   unsigned repNumber = product<unsigned>(repCluster);
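An aside, not part of the commit: with the usual instShapeC = {8, 16} and a 16-thread sub-group, the quantities computed here come out as rowsPerElem = 16 / 16 = 1 and colsPerElem = 1, i.e. successive per-thread values step one row down the C tile. A minimal standalone C++ sketch of that arithmetic:

```
// rowsPerElem: how many C-tile rows one per-thread value advances.
#include <cstdio>

int main() {
  const unsigned instShapeC[2] = {8, 16}; // rows x cols of one DPAS C tile
  const unsigned threadsPerWarp = 16;     // product of getThreadsPerWarp()
  const unsigned rowsPerElem = threadsPerWarp / instShapeC[1]; // 1
  const unsigned colsPerElem = 1;
  std::printf("rowsPerElem=%u colsPerElem=%u\n", rowsPerElem, colsPerElem);
}
```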

third_party/intel/lib/TritonIntelGPUTransforms/OptimizeReductionLocality.cpp

Lines changed: 4 additions & 2 deletions
@@ -9,6 +9,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "intel/include/Dialect/TritonIntelGPU/Transforms/Passes.h"
+#include "triton/Dialect/Triton/IR/Utility.h"
 
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
@@ -175,11 +176,12 @@ struct DpasOperandPattern final : OpRewritePattern<ReduceOp> {
 
     // We want to transpose matrices of N*threads_per_warpxthreads_per_warp
     // shape.
+    unsigned threadsPerWarp = product<unsigned>(encoding.getThreadsPerWarp());
     if ( // X axis condition
-        encoding.getExecutionSize() != encoding.getSubGroupSize() ||
+        encoding.getExecutionSize() != threadsPerWarp ||
         // Y axis conditions
         (encoding.getRepeatCount() * encoding.getRepCluster()[0]) %
-                encoding.getSubGroupSize() !=
+                threadsPerWarp !=
             0)
       return failure();
 
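An aside, not part of the commit: plugging in the example encoding from the documentation above (executionSize = 16, threadsPerWarp = 16, repeatCount = 8, repCluster[0] = 2) shows how the rewritten guard decides whether the reduction pattern applies. A minimal standalone C++ sketch:

```
// The pattern bails out unless both axis conditions hold.
#include <cstdio>

int main() {
  const unsigned executionSize = 16, threadsPerWarp = 16;
  const unsigned repeatCount = 8, repCluster0 = 2;
  const bool xOk = executionSize == threadsPerWarp;                   // X axis
  const bool yOk = (repeatCount * repCluster0) % threadsPerWarp == 0; // Y axis
  std::printf("pattern applies: %s\n", xOk && yOk ? "yes" : "no");
}
```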