
Commit 7fc674a

Update the documents based on review comments.

1 parent eda4c40

6 files changed: +78 additions, -64 deletions


third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td

Lines changed: 63 additions & 52 deletions
@@ -23,18 +23,21 @@ The encoding is characterized by parameters:
 - `repeatCount` which shall be in the range [1, 8]
 - `systolicDepth` For PVC/ATSM, the size is 8.
 - `executionSize` For PVC, the size is 16. For ATSM, the size is 8.
-- `opsPerChannel` 4 for 8 bit scalar type, 2 for 16 bit scalar type, 1 for 32 bit scalar type.
+- `opsPerChannel` 4 for 8 bit scalar type of A/B operands of the DPAS instruction,
+  2 for 16 bit scalar type of A/B operands of the DPAS instruction,
+  1 for 32 bit scalar type of A/B operands of the DPAS instruction.
 - `warpsPerCTA` indicates the distribution of the warps in the block. The order is [1, 0] for rank 2.
 - `repCluster` indicates the cluster size of the repetitions of the DPAS tile.
-- `sugGroupSize` Currently only sub group size 16 is supported.
+- `threadsPerWarp_` AKA threadsPerWarp; the trailing underscore avoids a conflict
+  with `getThreadsPerWarp` in the DistributedLayout interface. Currently only 16 is supported.

 The values of the matrix is distributed across the threads in the subgroup as row-major order.
-- If the column size of the matrix is equal to the number of threads in the subgroup, a single value name represents a single rows of the matrix.
-- If the column size of the matrix is less than the number of threads in the subgroup, a single value name represents multiple rows of the matrix.
-- If the column size of the matrix is larger than the number of the threads in the subgroup, a single row of the matrix requires multiple value name.
+- If the column size of the matrix is equal to the number of threads in the subgroup, one scalar represents one row of the matrix in register.
+- If the column size of the matrix is less than the number of threads in the subgroup, one scalar represents multiple rows of the matrix in register.
+- If the column size of the matrix is larger than the number of threads in the subgroup, one scalar represents a partial row of the matrix in register.

 Example 1, the column size of the matrix is 16 and the number of threads in the subgroup is 16.
-The DPAS encoding of repeatCount=8, systolicDepth=8, executionSize=16, opsPerChannel=2 and sugGroupSize=16.
+The DPAS encoding of repeatCount=8, systolicDepth=8, executionSize=16, opsPerChannel=2 and threadsPerWarp=16.

 The layout for A operand:
 K = 16 (K = systolic depth * opsPerChan)
@@ -83,7 +86,7 @@ t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15
 t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15  v

 Example 2, the column size of the matrix is 8 and the number of threads in the subgroup is 16.
-The DPAS encoding of repeatCount=8, systolicDepth=8, executionSize=16, opsPerChannel=1 and sugGroupSize=16.
+The DPAS encoding of repeatCount=8, systolicDepth=8, executionSize=16, opsPerChannel=1 and threadsPerWarp=16.

 The layout for A operand:
 K = 8 (K = systolic depth * opsPerChan)
@@ -102,7 +105,7 @@ The layouts for B operand is like the one of opsPerChan=2 but the K size is 8.
 The layouts for C and D operands are same as the one of opsPerChan=2.

 Example 3, the column size of the matrix is 32 and the number of threads in the subgroup is 16.
-The DPAS encoding of repeatCount=8, systolicDepth=8, executionSize=16, opsPerChannel=4 and sugGroupSize=16.
+The DPAS encoding of repeatCount=8, systolicDepth=8, executionSize=16, opsPerChannel=4 and threadsPerWarp=16.

 The layout for A operand:
 K = 32 (K = systolic depth * opsPerChan)
@@ -121,49 +124,57 @@ The layouts for B operand is like the one of opsPerChan=2 but the K size is 32.
 The layouts for C and D operands are same as the one of opsPerChan=2.

 The patterns (illustrated above) repeats every warpsPerTile[0] (resp. warpsPerTile[1]) blocks
-along the row (resp. col) dimension. And the repetitions are clustered of the size of repCluster to optimize the memory accessing.
-
-Suppose we have a `tt.dot` operation of the block size [64, 128] += [64, 32] * [32, 128] of hf16/bf16.
-The `warpsPerCTA` set to [2, 2]. The number of repetitions of the DPAS tile per warp is: A=8, B=8, C,D=16.
-The DPAS repetitions are distributed as follows:
-
-                          warp[:0]  warp[:1]  warp[:0]  warp[:1]
-                         |----^----|----^----|----^----|----^----|
-                           repCluster[1]
-                          <--------->
-                         ┌────┬────┬────┬────┬────┬────┬────┬────┐
-                         │R0  │R1  │    │    │R4  │R5  │    │    │
-                         │    │    │    │    │    │    │    │    │
-                         ├────┼────┼────┼────┼────┼────┼────┼────┤
-                         │R2  │R3  │    │    │R6  │R7  │    │    │
-                         │    │    │    │    │    │    │    │    │
-                         └────┴────┴────┴────┴────┴────┴────┴────┘
-
-  -  ^                       ┌────┬────┐   ┌────┬────┬────┬────┬────┬────┬────┬────┐
-  |  |                       │R0  │R2  │   │R0  │R1  │    │    │R4  │R5  │    │    │
-  |  |                       │    │    │   │    │    │    │    │    │    │    │    │
- warp[0:] < repCluster[0] |  ]────┼────┤   ├────┼────┼────┼────┼────┼────┼────┼────┤
-  |  |                       │R1  │R3  │   │R2  │R3  │    │    │R6  │R7  │    │    │
-  |  |                       │    │    │   │    │    │    │    │    │    │    │    │
-  -  v                       ├────┼────┤   ├────┼────┼────┼────┼────┼────┼────┼────┤
-  |                          │    │    │   │    │    │    │    │    │    │    │    │
-  |                          │    │    │   │    │    │    │    │    │    │    │    │
- warp[1:] <                  ├────┼────┤   ├────┼────┼────┼────┼────┼────┼────┼────┤
-  |                          │    │    │   │    │    │    │    │    │    │    │    │
-  |                          │    │    │   │    │    │    │    │    │    │    │    │
-  -                          ├────┼────┤   ├────┼────┼────┼────┼────┼────┼────┼────┤
-  |                          │R4  │R6  │   │R8  │R9  │    │    │R12 │R13 │    │    │
-  |                          │    │    │   │    │    │    │    │    │    │    │    │
- warp[0:] <                  ├────┼────┤   ├────┼────┼────┼────┼────┼────┼────┼────┤
-  |                          │R5  │R7  │   │R10 │R11 │    │    │R14 │R15 │    │    │
-  |                          │    │    │   │    │    │    │    │    │    │    │    │
-  -                          ├────┼────┤   ├────┼────┼────┼────┼────┼────┼────┼────┤
-  |                          │    │    │   │    │    │    │    │    │    │    │    │
-  |                          │    │    │   │    │    │    │    │    │    │    │    │
- warp[1:] <                  ├────┼────┤   ├────┼────┼────┼────┼────┼────┼────┼────┤
-  |                          │    │    │   │    │    │    │    │    │    │    │    │
-  |                          │    │    │   │    │    │    │    │    │    │    │    │
-  -                          └────┴────┘   └────┴────┴────┴────┴────┴────┴────┴────┘
+along the row (resp. col) dimension. The repetitions are clustered in groups of repCluster size to optimize memory access.
+
+Suppose we have a `tt.dot` operation of the block size [64, 128] = [64, 32] * [32, 128] of f16/bf16, and its input tensor layout is defined as follows:
+```
+#dpas = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [2, 2], repCluster = [2, 2]}>
+#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#dpas, kWidth=2}>
+#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#dpas, kWidth=2}>
+
+%d = tt.dot %a, %b, %c : tensor<64x32xf16, #dot_operand_a> * tensor<32x128xf16, #dot_operand_b> -> tensor<64x128xf32, #dpas>
+```
+The semantics of this `tt.dot` include the following GEMM tiling configuration:
+
+                          warp[:0]  warp[:1]  warp[:0]  warp[:1]
+                         |----^----|----^----|----^----|----^----|
+                           repCluster[1]
+                          <--------->
+                         ┌────┬────┬────┬────┬────┬────┬────┬────┐
+                         │W0R0│W0R1│W1R0│W1R1│W0R4│W0R5│W1R4│W1R5│
+                         │W2R0│W2R1│W3R0│W3R1│W2R4│W2R5│W3R4│W3R5│
+ warpPerCTA = [[W0, W1], ├────┼────┼────┼────┼────┼────┼────┼────┤
+               [W2, W3]] │W0R2│W0R3│W1R2│W1R3│W0R6│W0R7│W1R6│W1R7│
+                         │W2R2│W2R3│W3R2│W3R3│W2R6│W2R7│W3R6│W3R7│
+                         └────┴────┴────┴────┴────┴────┴────┴────┘
+
+
+  -  ^                       ┌────┬────┐   ┌────┬────┬────┬────┬────┬────┬────┬────┐
+  |  |                       │W0R0│W0R2│   │W0R0│W0R1│W1R0│W1R1│W0R4│W0R5│W1R4│W1R5│
+  |  |                       │W1R0│W1R2│   │    │    │    │    │    │    │    │    │
+ warp[0:] < repCluster[0] |  ├────┼────┤   ├────┼────┼────┼────┼────┼────┼────┼────┤
+  |  |                       │W0R1│W0R3│   │W0R2│W0R3│W1R2│W1R3│W0R6│W0R7│W1R6│W1R7│
+  |  |                       │W1R1│W1R3│   │    │    │    │    │    │    │    │    │
+  -  v                       ├────┼────┤   ├────┼────┼────┼────┼────┼────┼────┼────┤
+  |                          │W2R0│W2R2│   │W2R0│W2R1│W3R0│W3R1│W2R4│W2R5│W3R4│W3R5│
+  |                          │W3R0│W3R2│   │    │    │    │    │    │    │    │    │
+ warp[1:] <                  ├────┼────┤   ├────┼────┼────┼────┼────┼────┼────┼────┤
+  |                          │W2R1│W2R3│   │W2R2│W2R3│W3R2│W3R3│W2R6│W2R7│W3R6│W3R7│
+  |                          │W3R1│W3R3│   │    │    │    │    │    │    │    │    │
+  -                          ├────┼────┤   ├────┼────┼────┼────┼────┼────┼────┼────┤
+  |                          │W0R4│W0R6│   │W0R8│W0R9│W1R8│W1R9│W0  │W0  │W1  │W1  │
+  |                          │W1R4│W1R6│   │    │    │    │    │R12 │R13 │R12 │R13 │
+ warp[0:] <                  ├────┼────┤   ├────┼────┼────┼────┼────┼────┼────┼────┤
+  |                          │W0R5│W0R7│   │W0  │W0  │W1  │W1  │W0  │W0  │W1  │W1  │
+  |                          │W1R5│W1R7│   │R10 │R11 │R10 │R11 │R14 │R15 │R14 │R15 │
+  -                          ├────┼────┤   ├────┼────┼────┼────┼────┼────┼────┼────┤
+  |                          │W2R4│W2R6│   │W2R8│W2R9│W3R8│W3R9│W2  │W2  │W3  │W3  │
+  |                          │W3R4│W3R6│   │    │    │    │    │R12 │R13 │R12 │R13 │
+ warp[1:] <                  ├────┼────┤   ├────┼────┼────┼────┼────┼────┼────┼────┤
+  |                          │W2R5│W2R7│   │W2  │W2  │W3  │W3  │W2  │W2  │W3  │W3  │
+  |                          │W3R5│W3R7│   │R10 │R11 │R10 │R11 │R14 │R15 │R14 │R15 │
+  -                          └────┴────┘   └────┴────┴────┴────┴────┴────┴────┴────┘
+

 }];

@@ -175,7 +186,7 @@ The DPAS repetitions are distributed as follows:
     "unsigned":$opsPerChannel,
     ArrayRefParameter<"unsigned">:$warpsPerCTA__,
     ArrayRefParameter<"unsigned">:$repCluster,
-    "unsigned":$subGroupSize
+    "unsigned":$threadsPerWarp_
   );

   let extraClassDeclaration = extraDistributedDeclaration # [{
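
Editor's note: the deleted doc text quoted per-warp repetition counts of A=8, B=8, C,D=16 without showing the arithmetic. The standalone C++ sketch below reproduces those counts from the parameters of the example `#dpas` encoding above; all helper names are illustrative, not the dialect's actual API.

```
#include <array>
#include <cassert>
#include <cstdio>

// opsPerChannel rule from the doc: 4 for 8-bit, 2 for 16-bit, 1 for 32-bit
// A/B scalar types.
unsigned opsPerChannel(unsigned scalarBits) {
  assert(scalarBits == 8 || scalarBits == 16 || scalarBits == 32);
  return 32 / scalarBits;
}

int main() {
  const unsigned repeatCount = 8, systolicDepth = 8, executionSize = 16;
  const unsigned opsPerChan = opsPerChannel(16);     // f16/bf16 -> 2
  const unsigned dpasK = systolicDepth * opsPerChan; // K per DPAS inst = 16

  // Block shape [64, 128] = [64, 32] * [32, 128] with warpsPerCTA = [2, 2]
  // and repCluster = [2, 2], as in the example encoding above.
  const unsigned M = 64, N = 128, K = 32;
  const std::array<unsigned, 2> warps{2, 2}, cluster{2, 2};

  const unsigned repM = M / (repeatCount * cluster[0] * warps[0]);   // 2
  const unsigned repN = N / (executionSize * cluster[1] * warps[1]); // 2
  const unsigned repK = K / dpasK;                                   // 2

  // Per-warp DPAS tile repetitions, matching the old text "A=8, B=8, C,D=16".
  std::printf("A=%u B=%u C,D=%u\n",
              repM * cluster[0] * repK,               // 8
              repK * repN * cluster[1],               // 8
              repM * cluster[0] * repN * cluster[1]); // 16
}
```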

third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp

Lines changed: 7 additions & 7 deletions
@@ -134,7 +134,7 @@ SmallVector<unsigned> DpasEncodingAttr::getShapeC() const {
 SmallVector<unsigned> DpasEncodingAttr::getSizePerThread() const {
   size_t rank = getWarpsPerCTA().size();
   SmallVector<unsigned> res(rank, 1);
-  unsigned threadsPerWarp = getSubGroupSize();
+  unsigned threadsPerWarp = getThreadsPerWarp_();
   auto shapeC = getDPASInstShapeC();
   unsigned elemsNum = product<unsigned>(shapeC);
   unsigned elemsPerThread = elemsNum / threadsPerWarp;
@@ -260,7 +260,7 @@ unsigned DpasEncodingAttr::getTotalElemsPerThreadForOperand(
     ArrayRef<int64_t> shape, mlir::Type eltTy, int kWidth, int opIdx) const {
   auto shapePerCTA = getShapePerCTA(*this, shape);
   auto rep = getDPASRepetitions(shapePerCTA, opIdx);
-  auto threadsPerWar = getSubGroupSize();
+  auto threadsPerWar = getThreadsPerWarp_();
   size_t rank = shape.size();
   if (opIdx == 0) {
     auto shapeA = getShapeA();
@@ -296,7 +296,7 @@ SmallVector<unsigned> DpasEncodingAttr::getThreadsPerWarp() const {
   size_t rank = getWarpsPerCTA().size();
   SmallVector<unsigned> res(rank, 1);
   auto executionSize = getExecutionSize();
-  auto subGroupSize = getSubGroupSize();
+  auto subGroupSize = getThreadsPerWarp_();
   if (subGroupSize < executionSize) {
     llvm::report_fatal_error("DpasEncodingAttr sub-group size could not be "
                              "smaller than the execution size");
@@ -340,7 +340,7 @@ DpasEncodingAttr::getSizePerThreadForOperand(int kWidth, unsigned opIdx) const {
   assert((rank == 2 || rank == 3) && "unexpected rank number for Dpas layout");
   if (opIdx == 0) {
     SmallVector<unsigned> shapeA = getDPASInstShapeA();
-    unsigned subGroupSize = getSubGroupSize();
+    unsigned subGroupSize = getThreadsPerWarp_();
     unsigned opsPerChannel = getOpsPerChannel();

     // pack the value to i16 for scalar bit width <=16.
@@ -359,7 +359,7 @@ DpasEncodingAttr::getSizePerThreadForOperand(int kWidth, unsigned opIdx) const {

   if (opIdx == 1) {
     auto shapeB = getShapeB();
-    auto subGroupSize = getSubGroupSize();
+    auto subGroupSize = getThreadsPerWarp_();
     auto executionSize = getExecutionSize();
     if (subGroupSize < executionSize) {
       llvm::report_fatal_error("DpasEncodingAttr sub-group size could not "
@@ -394,7 +394,7 @@ SmallVector<unsigned> DpasEncodingAttr::getContigPerThread() {
   assert(rank == 2 || rank == 3);
   SmallVector<unsigned> contigPerThread(rank, 1);

-  unsigned threadsPerWarp = getSubGroupSize();
+  unsigned threadsPerWarp = getThreadsPerWarp_();
   auto instShapeC = getDPASInstShapeC();
   // The software vectorization vectorized the value as C array: int a[N] -> int
   // a[N][threadsPerWarp]
@@ -506,7 +506,7 @@ void DpasEncodingAttr::print(AsmPrinter &printer) const {
           << "systolicDepth = " << getSystolicDepth() << ", "
           << "executionSize = " << getExecutionSize() << ", "
           << "opsPerChan = " << getOpsPerChannel() << ", "
-          << "threadsPerWarp = " << getSubGroupSize() << ", "
+          << "threadsPerWarp = " << getThreadsPerWarp_() << ", "
           << "warpsPerCTA = [" << llvm::ArrayRef<unsigned>(warpsPerCTA) << "], "
           << "repCluster = [" << repCluster << "], "
           << "A = [" << rA << "], "

third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 1 addition & 1 deletion
@@ -334,7 +334,7 @@ struct ConvertLayoutOpConversion
     size_t totalElems = elems.size();
     auto numElemsPerOperand =
         product<unsigned>(dpasLayout.getDPASInstShapeC()) /
-        dpasLayout.getSubGroupSize();
+        product<unsigned>(dpasLayout.getThreadsPerWarp());
     Type elemTy =
         this->getTypeConverter()->convertType(srcType.getElementType());
     VectorType dotOpTy = vec_ty(elemTy, numElemsPerOperand);
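
Editor's note: here the conversion switches from the scalar accessor to `product<unsigned>(...)` over the interface's per-dimension vector; both yield the warp size as long as the per-dimension counts multiply back to it. A minimal sketch of that equivalence, assuming a {1, 16} split for a 16-lane subgroup:

```
#include <functional>
#include <numeric>
#include <vector>

// Free-standing stand-in for Triton's product<unsigned>(...) reduction.
unsigned product(const std::vector<unsigned> &v) {
  return std::accumulate(v.begin(), v.end(), 1u, std::multiplies<unsigned>());
}

int main() {
  std::vector<unsigned> threadsPerWarp{1, 16}; // assumed per-dimension split
  // Same value the removed getSubGroupSize() call returned.
  return product(threadsPerWarp) == 16 ? 0 : 1;
}
```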

third_party/intel/lib/TritonIntelGPUToLLVM/DotOpToLLVM/DPAS.cpp

Lines changed: 1 addition & 1 deletion
@@ -36,7 +36,7 @@ class DotOpDPASConversionHelper {
     Type i16Ty = type::i16Ty(ctx);
     Type s32Ty = IntegerType::get(ctx, 32, IntegerType::Signed);

-    unsigned threadsPerWarp = layout.getSubGroupSize();
+    unsigned threadsPerWarp = product<unsigned>(layout.getThreadsPerWarp());
     unsigned opsPerChannel = layout.getOpsPerChannel();
     SmallVector<unsigned> shapeC = layout.getDPASInstShapeC();
     unsigned elemNumC = product<unsigned>(shapeC) / threadsPerWarp;
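
Editor's note: with the same substitution, `elemNumC` works out to 8 accumulator elements per lane under the PVC defaults documented above (C tile of repeatCount x executionSize = 8 x 16, 16-lane warp). A sketch of that arithmetic:

```
// Per-thread C fragment size implied by the lines above; the shapes are
// the documented PVC defaults, assumed here for illustration.
unsigned elemNumC(unsigned repeatCount, unsigned executionSize,
                  unsigned threadsPerWarp) {
  return repeatCount * executionSize / threadsPerWarp; // 8 * 16 / 16 == 8
}
```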

third_party/intel/lib/TritonIntelGPUToLLVM/Utility.h

Lines changed: 2 additions & 1 deletion
@@ -168,7 +168,8 @@ emitOffsetForDpasLayoutPerCTA(const DpasEncodingAttr &dpasLayout,
       sizePerThreads[rank - 2] / repCluster[rank - 2],
       sizePerThreads[rank - 1] / repCluster[rank - 1]};

-  unsigned rowsPerElem = dpasLayout.getSubGroupSize() / instShapeC[1];
+  unsigned rowsPerElem =
+      product<unsigned>(dpasLayout.getThreadsPerWarp()) / instShapeC[1];
   unsigned colsPerElem = 1;

   unsigned repNumber = product<unsigned>(repCluster);
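
Editor's note: `rowsPerElem` is the code-side form of the layout rule documented in the .td file above: when a tile row has fewer columns than the warp has lanes, one scalar in a register covers several consecutive rows. A sketch, with `instCols` standing in for `instShapeC[1]`:

```
// Rows covered by one scalar in a register. For the PVC C tile
// (16 columns, 16 lanes) this is 1; for an 8-column tile as in the
// opsPerChannel=1 example it would be 2.
unsigned rowsPerElem(unsigned threadsPerWarp, unsigned instCols) {
  return threadsPerWarp / instCols;
}
```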
