Skip to content

Commit 9ec46fe

Browse files
[NFI]: Introduce a C++ typed enum to represent DPAS operand index (#2996)
Instead of using a plain `unsigned` to represent the operand indices of `tt.dot`, we can use a strongly typed C++ enum class. --------- Signed-off-by: Tiotto, Ettore <[email protected]> Co-authored-by: Whitney Tsang <[email protected]>
1 parent 7cb5b5d commit 9ec46fe

File tree

10 files changed

+247
-205
lines changed

10 files changed

+247
-205
lines changed

third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -74,17 +74,41 @@ along the row (resp. col) dimension.
7474
);
7575

7676
let extraClassDeclaration = extraDistributedDeclaration # [{
77+
enum class OpIdx : unsigned {
78+
OperandA = 0u,
79+
OperandB = 1u,
80+
OperandC = 2u
81+
};
82+
7783
SmallVector<unsigned> getDPASInstShapeA() const;
7884
SmallVector<unsigned> getDPASInstShapeB() const;
7985
SmallVector<unsigned> getDPASInstShapeC() const;
8086
SmallVector<unsigned> getShapeA() const;
8187
SmallVector<unsigned> getShapeB() const;
8288
SmallVector<unsigned> getShapeC() const;
83-
SmallVector<int64_t> getDPASRepetitions(ArrayRef<int64_t> shape, int opIdx) const;
84-
SmallVector<unsigned> getSizePerThreadForOperand(int kWidth,unsigned opIdx) const;
85-
SmallVector<unsigned> getElemsPerThreadForOperands(ArrayRef<int64_t> shape, Type eltTy, unsigned opIdx) const;
86-
SmallVector<unsigned> getRepOrderForOperand(int opIdx) const;
87-
unsigned getTotalElemsPerThreadForOperand(ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;
89+
90+
SmallVector<int64_t> getDPASRepetitions(ArrayRef<int64_t> shape, OpIdx opIdx) const;
91+
SmallVector<unsigned> getSizePerThreadForOperand(int kWidth, OpIdx opIdx) const;
92+
SmallVector<unsigned> getElemsPerThreadForOperands(ArrayRef<int64_t> shape, Type eltTy, OpIdx opIdx) const;
93+
SmallVector<unsigned> getRepOrderForOperand(OpIdx opIdx) const;
94+
unsigned getTotalElemsPerThreadForOperand(ArrayRef<int64_t> shape, Type eltTy, int kWidth, OpIdx opIdx) const;
95+
96+
// Forwarder functions for casting unsigned to OpIdx.
97+
SmallVector<int64_t> getDPASRepetitions(ArrayRef<int64_t> shape, unsigned opIdx) const {
98+
return getDPASRepetitions(shape, static_cast<OpIdx>(opIdx));
99+
}
100+
SmallVector<unsigned> getSizePerThreadForOperand(int kWidth, unsigned opIdx) const {
101+
return getSizePerThreadForOperand(kWidth, static_cast<OpIdx>(opIdx));
102+
}
103+
SmallVector<unsigned> getElemsPerThreadForOperands(ArrayRef<int64_t> shape, Type eltTy, unsigned opIdx) const {
104+
return getElemsPerThreadForOperands(shape, eltTy, static_cast<OpIdx>(opIdx));
105+
}
106+
SmallVector<unsigned> getRepOrderForOperand(unsigned opIdx) const {
107+
return getRepOrderForOperand(static_cast<OpIdx>(opIdx));
108+
}
109+
unsigned getTotalElemsPerThreadForOperand(ArrayRef<int64_t> shape, Type eltTy, int kWidth, unsigned opIdx) const {
110+
return getTotalElemsPerThreadForOperand(shape, eltTy, kWidth, static_cast<OpIdx>(opIdx));
111+
}
88112

89113
bool supportReduction() const {
90114
return true;

third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp

Lines changed: 74 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
#include "triton/Dialect/Triton/IR/Dialect.h"
22

3-
#include <numeric>
4-
53
#include "intel/include/Dialect/TritonIntelGPU/IR/LinearLayoutConversions.h"
64
#include "mlir/IR/DialectImplementation.h"
75
#include "mlir/IR/OpImplementation.h"
@@ -12,7 +10,9 @@
1210

1311
#include "intel/include/Dialect/TritonIntelGPU/IR/Dialect.cpp.inc"
1412

13+
#include "llvm/ADT/SmallVector.h"
1514
#include "llvm/ADT/TypeSwitch.h"
15+
#include "llvm/Support/ErrorHandling.h"
1616

1717
using namespace mlir;
1818
using namespace mlir::triton;
@@ -102,8 +102,8 @@ SmallVector<unsigned> DpasEncodingAttr::getDPASInstShapeC() const {
102102
};
103103

104104
SmallVector<unsigned> DpasEncodingAttr::getShapeA() const {
105-
auto instShapeA = getDPASInstShapeA();
106-
auto repCluster = getRepCluster();
105+
SmallVector<unsigned> instShapeA = getDPASInstShapeA();
106+
ArrayRef<unsigned> repCluster = getRepCluster();
107107
size_t rank = repCluster.size();
108108
SmallVector<unsigned> resShape(rank, 1);
109109
resShape[rank - 2] = instShapeA[0] * repCluster[rank - 2];
@@ -112,8 +112,8 @@ SmallVector<unsigned> DpasEncodingAttr::getShapeA() const {
112112
}
113113

114114
SmallVector<unsigned> DpasEncodingAttr::getShapeB() const {
115-
auto instShapeB = getDPASInstShapeB();
116-
auto repCluster = getRepCluster();
115+
SmallVector<unsigned> instShapeB = getDPASInstShapeB();
116+
ArrayRef<unsigned> repCluster = getRepCluster();
117117
size_t rank = repCluster.size();
118118
SmallVector<unsigned> resShape(rank, 1);
119119
resShape[rank - 2] = instShapeB[0];
@@ -122,8 +122,8 @@ SmallVector<unsigned> DpasEncodingAttr::getShapeB() const {
122122
}
123123

124124
SmallVector<unsigned> DpasEncodingAttr::getShapeC() const {
125-
auto instShapeC = getDPASInstShapeC();
126-
auto repCluster = getRepCluster();
125+
SmallVector<unsigned> instShapeC = getDPASInstShapeC();
126+
ArrayRef<unsigned> repCluster = getRepCluster();
127127
size_t rank = repCluster.size();
128128
SmallVector<unsigned> resShape(rank, 1);
129129
resShape[rank - 2] = instShapeC[0] * repCluster[rank - 2];
@@ -135,7 +135,7 @@ SmallVector<unsigned> DpasEncodingAttr::getSizePerThread() const {
135135
size_t rank = getWarpsPerCTA().size();
136136
SmallVector<unsigned> res(rank, 1);
137137
unsigned threadsPerWarp = getSubGroupSize();
138-
auto shapeC = getDPASInstShapeC();
138+
SmallVector<unsigned> shapeC = getDPASInstShapeC();
139139
unsigned elemsNum = product<unsigned>(shapeC);
140140
unsigned elemsPerThread = elemsNum / threadsPerWarp;
141141
auto repCluster = getRepCluster();
@@ -151,9 +151,10 @@ SmallVector<unsigned> DpasEncodingAttr::getRepOrder() const {
151151
llvm::report_fatal_error("NYI. DpasEncodingAttr::getRepOrder");
152152
}
153153

154-
SmallVector<unsigned> DpasEncodingAttr::getRepOrderForOperand(int opIdx) const {
155-
auto rank = getWarpsPerCTA().size();
156-
return getOrderForDotOperand(opIdx, rank, /*kMajor*/ true);
154+
SmallVector<unsigned>
155+
DpasEncodingAttr::getRepOrderForOperand(OpIdx opIdx) const {
156+
size_t rank = getWarpsPerCTA().size();
157+
return getOrderForDotOperand(unsigned(opIdx), rank, /*kMajor*/ true);
157158
}
158159

159160
SmallVector<unsigned>
@@ -162,8 +163,7 @@ DpasEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape, Type eltTy) const {
162163
assert((rank == 2 || rank == 3) && "Unexpected rank of mma layout");
163164

164165
SmallVector<unsigned> elemsPerThread(rank, 1);
165-
166-
auto shapeC = getShapeC();
166+
SmallVector<unsigned> shapeC = getShapeC();
167167
SmallVector<unsigned> warpsPerCTA = getWarpsPerCTA();
168168
SmallVector<unsigned> shapePerCTATile(rank);
169169
llvm::transform(
@@ -174,7 +174,7 @@ DpasEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape, Type eltTy) const {
174174
ceil<unsigned>(shape[rank - 2], shapePerCTATile[rank - 2]);
175175
unsigned tilesCol =
176176
ceil<unsigned>(shape[rank - 1], shapePerCTATile[rank - 1]);
177-
auto sizePerThread = getSizePerThread();
177+
SmallVector<unsigned> sizePerThread = getSizePerThread();
178178
if (rank == 3)
179179
elemsPerThread[0] =
180180
sizePerThread[0] * ceil<unsigned>(shape[0], shapePerCTATile[0]);
@@ -208,14 +208,16 @@ SmallVector<unsigned> DpasEncodingAttr::getCTAsPerCGA() const {
208208
}
209209

210210
SmallVector<int64_t>
211-
DpasEncodingAttr::getDPASRepetitions(ArrayRef<int64_t> shape, int opIdx) const {
211+
DpasEncodingAttr::getDPASRepetitions(ArrayRef<int64_t> shape,
212+
OpIdx opIdx) const {
212213
// Always return 3D shape repetitions for ease of value handling, the same
213214
// as for mma.
214-
auto warpsPerCTA = getWarpsPerCTA();
215-
int rank = shape.size();
215+
SmallVector<unsigned> warpsPerCTA = getWarpsPerCTA();
216+
size_t rank = shape.size();
216217
SmallVector<int64_t> rep(3, 1);
217-
if (opIdx == 0) {
218-
auto shapePerWarp = getShapeA();
218+
switch (opIdx) {
219+
case OpIdx::OperandA: {
220+
SmallVector<unsigned> shapePerWarp = getShapeA();
219221
int64_t numRepBatch =
220222
rank == 3 ? std::max<int64_t>(1, shape[0] /
221223
(shapePerWarp[0] * warpsPerCTA[0]))
@@ -224,10 +226,9 @@ DpasEncodingAttr::getDPASRepetitions(ArrayRef<int64_t> shape, int opIdx) const {
224226
std::max<int64_t>(1, shape[rank - 2] / (shapePerWarp[rank - 2] *
225227
warpsPerCTA[rank - 2])),
226228
std::max<int64_t>(1, shape[rank - 1] / shapePerWarp[rank - 1])};
227-
}
228-
229-
if (opIdx == 1) {
230-
auto shapePerWarp = getShapeB();
229+
} break;
230+
case OpIdx::OperandB: {
231+
SmallVector<unsigned> shapePerWarp = getShapeB();
231232
int64_t numRepBatch =
232233
rank == 3 ? std::max<int64_t>(1, shape[0] /
233234
(shapePerWarp[0] * warpsPerCTA[0]))
@@ -236,9 +237,9 @@ DpasEncodingAttr::getDPASRepetitions(ArrayRef<int64_t> shape, int opIdx) const {
236237
std::max<int64_t>(1, shape[rank - 2] / shapePerWarp[rank - 2]),
237238
std::max<int64_t>(1, shape[rank - 1] / (shapePerWarp[rank - 1] *
238239
warpsPerCTA[rank - 1]))};
240+
} break;
239241
}
240242

241-
assert(opIdx == 2 && "Unexpected operand id (valid ids are 0, 1 or 2)");
242243
auto shapePerWarp = getShapeC();
243244
int64_t numRepBatch =
244245
rank == 3
@@ -252,24 +253,27 @@ DpasEncodingAttr::getDPASRepetitions(ArrayRef<int64_t> shape, int opIdx) const {
252253
}
253254

254255
unsigned DpasEncodingAttr::getTotalElemsPerThreadForOperand(
255-
ArrayRef<int64_t> shape, mlir::Type eltTy, int kWidth, int opIdx) const {
256-
auto shapePerCTA = getShapePerCTA(*this, shape);
257-
auto rep = getDPASRepetitions(shapePerCTA, opIdx);
258-
auto threadsPerWar = getSubGroupSize();
256+
ArrayRef<int64_t> shape, mlir::Type eltTy, int kWidth, OpIdx opIdx) const {
257+
SmallVector<int64_t> shapePerCTA = getShapePerCTA(*this, shape);
258+
SmallVector<int64_t> rep = getDPASRepetitions(shapePerCTA, opIdx);
259+
unsigned threadsPerWar = getSubGroupSize();
259260
size_t rank = shape.size();
260-
if (opIdx == 0) {
261-
auto shapeA = getShapeA();
261+
262+
switch (opIdx) {
263+
case OpIdx::OperandA: {
264+
SmallVector<unsigned> shapeA = getShapeA();
262265
auto totalElem = product<unsigned>(shapeA);
263266
// DPAS operand scalars are evenly sharded across the work items.
264267
return (totalElem / threadsPerWar) * product<int64_t>(rep);
265-
}
266-
if (opIdx == 1) {
267-
auto shapeB = getShapeB();
268+
} break;
269+
case OpIdx::OperandB: {
270+
SmallVector<unsigned> shapeB = getShapeB();
268271
auto totalElem = product<unsigned>(shapeB);
269272
// DPAS operand scalars are evenly sharded across the work items.
270273
return (totalElem / threadsPerWar) * product<int64_t>(rep);
274+
} break;
271275
}
272-
llvm_unreachable("DpasEncodingAttr opIdx must be 0 or 1");
276+
llvm_unreachable("unexpected opIdx");
273277
}
274278

275279
SmallVector<unsigned> DpasEncodingAttr::getWarpOrder() const {
@@ -290,8 +294,8 @@ SmallVector<unsigned> DpasEncodingAttr::getWarpsPerCTA() const {
290294
SmallVector<unsigned> DpasEncodingAttr::getThreadsPerWarp() const {
291295
size_t rank = getWarpsPerCTA().size();
292296
SmallVector<unsigned> res(rank, 1);
293-
auto executionSize = getExecutionSize();
294-
auto subGroupSize = getSubGroupSize();
297+
unsigned executionSize = getExecutionSize();
298+
unsigned subGroupSize = getSubGroupSize();
295299
if (subGroupSize < executionSize) {
296300
llvm::report_fatal_error("DpasEncodingAttr sub-group size could not be "
297301
"smaller than the execution size");
@@ -302,11 +306,13 @@ SmallVector<unsigned> DpasEncodingAttr::getThreadsPerWarp() const {
302306
}
303307

304308
SmallVector<unsigned>
305-
DpasEncodingAttr::getSizePerThreadForOperand(int kWidth, unsigned opIdx) const {
309+
DpasEncodingAttr::getSizePerThreadForOperand(int kWidth, OpIdx opIdx) const {
306310
ArrayRef<unsigned> repCluster = getRepCluster();
307311
size_t rank = repCluster.size();
308312
assert((rank == 2 || rank == 3) && "unexpected rank number for Dpas layout");
309-
if (opIdx == 0) {
313+
314+
switch (opIdx) {
315+
case OpIdx::OperandA: {
310316
SmallVector<unsigned> shapeA = getDPASInstShapeA();
311317
unsigned subGroupSize = getSubGroupSize();
312318
unsigned opsPerChannel = getOpsPerChannel();
@@ -323,12 +329,11 @@ DpasEncodingAttr::getSizePerThreadForOperand(int kWidth, unsigned opIdx) const {
323329
}
324330
unsigned rowsPerWarp = mlir::ceil<unsigned>(subGroupSize, packedColNum);
325331
return {shapeA[0] / rowsPerWarp * repCluster[rank - 2], packedOpsPerLane};
326-
}
327-
328-
if (opIdx == 1) {
329-
auto shapeB = getShapeB();
330-
auto subGroupSize = getSubGroupSize();
331-
auto executionSize = getExecutionSize();
332+
} break;
333+
case OpIdx::OperandB: {
334+
SmallVector<unsigned> shapeB = getShapeB();
335+
unsigned subGroupSize = getSubGroupSize();
336+
unsigned executionSize = getExecutionSize();
332337
if (subGroupSize < executionSize) {
333338
llvm::report_fatal_error("DpasEncodingAttr sub-group size could not "
334339
"be smaller than the execution size");
@@ -337,13 +342,14 @@ DpasEncodingAttr::getSizePerThreadForOperand(int kWidth, unsigned opIdx) const {
337342
executionSize};
338343
return {shapeB[rank - 2] / threadsPerWarp[0],
339344
shapeB[rank - 1] / threadsPerWarp[1] * repCluster[rank - 1]};
345+
} break;
340346
}
341-
342-
llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1");
347+
llvm_unreachable("unexpected opIdx");
343348
}
344349

345-
SmallVector<unsigned> DpasEncodingAttr::getElemsPerThreadForOperands(
346-
ArrayRef<int64_t> shape, Type eltTy, unsigned opIdx) const {
350+
SmallVector<unsigned>
351+
DpasEncodingAttr::getElemsPerThreadForOperands(ArrayRef<int64_t> shape,
352+
Type eltTy, OpIdx opIdx) const {
347353
SmallVector<unsigned> sizePerThread = getSizePerThreadForOperand(0, opIdx);
348354
SmallVector<int64_t> repetitions = getDPASRepetitions(shape, opIdx);
349355

@@ -363,15 +369,15 @@ SmallVector<unsigned> DpasEncodingAttr::getContigPerThread() const {
363369
SmallVector<unsigned> contigPerThread(rank, 1);
364370

365371
unsigned threadsPerWarp = getSubGroupSize();
366-
auto instShapeC = getDPASInstShapeC();
367-
// The software vectorization vectorized the value as C array: int a[N] -> int
368-
// a[N][threadsPerWarp]
372+
SmallVector<unsigned> instShapeC = getDPASInstShapeC();
373+
// The software vectorization vectorized the value as C array: int a[N] ->
374+
// int a[N][threadsPerWarp]
369375
if (threadsPerWarp > instShapeC[1]) {
370376
return contigPerThread;
371377
}
372378

373379
if (threadsPerWarp == instShapeC[1]) {
374-
auto repCluster = getRepCluster();
380+
ArrayRef<unsigned> repCluster = getRepCluster();
375381
contigPerThread[rank - 2] = instShapeC[0] * repCluster[rank - 2];
376382
return contigPerThread;
377383
}
@@ -485,14 +491,14 @@ Attribute DpasEncodingAttr::parse(AsmParser &parser, Type type) {
485491
}
486492

487493
void DpasEncodingAttr::print(AsmPrinter &printer) const {
488-
auto shapeA = getShapeA();
489-
llvm::ArrayRef<unsigned> rA = shapeA;
490-
auto shapeB = getShapeB();
491-
llvm::ArrayRef<unsigned> rB = shapeB;
492-
auto shapeC = getShapeC();
493-
llvm::ArrayRef<unsigned> rC = shapeC;
494-
auto warpsPerCTA = getWarpsPerCTA();
495-
auto repCluster = getRepCluster();
494+
SmallVector<unsigned> shapeA = getShapeA();
495+
ArrayRef<unsigned> rA = shapeA;
496+
SmallVector<unsigned> shapeB = getShapeB();
497+
ArrayRef<unsigned> rB = shapeB;
498+
SmallVector<unsigned> shapeC = getShapeC();
499+
ArrayRef<unsigned> rC = shapeC;
500+
SmallVector<unsigned> warpsPerCTA = getWarpsPerCTA();
501+
ArrayRef<unsigned> repCluster = getRepCluster();
496502
printer << "<{" << "repeatCount = " << getRepeatCount() << ", "
497503
<< "systolicDepth = " << getSystolicDepth() << ", "
498504
<< "executionSize = " << getExecutionSize() << ", "
@@ -515,8 +521,8 @@ DpasEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
515521
SmallVector<unsigned>
516522
WarpEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape, Type eltTy) const {
517523
size_t rank = shape.size();
518-
auto sizePerThread = getSizePerThread();
519-
auto threadsPerWarp = getThreadsPerWarp();
524+
ArrayRef<unsigned> sizePerThread = getSizePerThread();
525+
ArrayRef<unsigned> threadsPerWarp = getThreadsPerWarp();
520526
assert(rank == sizePerThread.size() &&
521527
"unexpected rank in WarpEncodingAttr::getElemsPerThread");
522528
SmallVector<unsigned> elemsPerThread(rank);
@@ -571,12 +577,11 @@ Attribute WarpEncodingAttr::parse(AsmParser &parser, Type type) {
571577
}
572578

573579
void WarpEncodingAttr::print(mlir::AsmPrinter &printer) const {
574-
auto threadsPerWarp = getThreadsPerWarp();
575-
auto sizePerThread = getSizePerThread();
576-
printer << "<{" << "sizePerThread = ["
577-
<< llvm::ArrayRef<unsigned>(sizePerThread) << "]"
578-
<< ", threadsPerWarp = [" << llvm::ArrayRef<unsigned>(threadsPerWarp)
579-
<< "]" << ", order = [" << getOrder() << "]" << "}>";
580+
ArrayRef<unsigned> threadsPerWarp = getThreadsPerWarp();
581+
ArrayRef<unsigned> sizePerThread = getSizePerThread();
582+
printer << "<{" << "sizePerThread = [" << sizePerThread << "]"
583+
<< ", threadsPerWarp = [" << threadsPerWarp << "]" << ", order = ["
584+
<< getOrder() << "]" << "}>";
580585
}
581586

582587
//===----------------------------------------------------------------------===//
@@ -676,7 +681,6 @@ struct TritonIntelGPUInferLayoutInterface
676681
//===----------------------------------------------------------------------===//
677682

678683
void TritonIntelGPUDialect::initialize() {
679-
680684
addAttributes<
681685
#define GET_ATTRDEF_LIST
682686
#include "intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.cpp.inc"

0 commit comments

Comments
 (0)