11#include " triton/Dialect/Triton/IR/Dialect.h"
22
3- #include < numeric>
4-
53#include " intel/include/Dialect/TritonIntelGPU/IR/LinearLayoutConversions.h"
64#include " mlir/IR/DialectImplementation.h"
75#include " mlir/IR/OpImplementation.h"
1210
1311#include " intel/include/Dialect/TritonIntelGPU/IR/Dialect.cpp.inc"
1412
13+ #include " llvm/ADT/SmallVector.h"
1514#include " llvm/ADT/TypeSwitch.h"
15+ #include " llvm/Support/ErrorHandling.h"
1616
1717using namespace mlir ;
1818using namespace mlir ::triton;
@@ -102,8 +102,8 @@ SmallVector<unsigned> DpasEncodingAttr::getDPASInstShapeC() const {
102102};
103103
104104SmallVector<unsigned > DpasEncodingAttr::getShapeA () const {
105- auto instShapeA = getDPASInstShapeA ();
106- auto repCluster = getRepCluster ();
105+ SmallVector< unsigned > instShapeA = getDPASInstShapeA ();
106+ ArrayRef< unsigned > repCluster = getRepCluster ();
107107 size_t rank = repCluster.size ();
108108 SmallVector<unsigned > resShape (rank, 1 );
109109 resShape[rank - 2 ] = instShapeA[0 ] * repCluster[rank - 2 ];
@@ -112,8 +112,8 @@ SmallVector<unsigned> DpasEncodingAttr::getShapeA() const {
112112}
113113
114114SmallVector<unsigned > DpasEncodingAttr::getShapeB () const {
115- auto instShapeB = getDPASInstShapeB ();
116- auto repCluster = getRepCluster ();
115+ SmallVector< unsigned > instShapeB = getDPASInstShapeB ();
116+ ArrayRef< unsigned > repCluster = getRepCluster ();
117117 size_t rank = repCluster.size ();
118118 SmallVector<unsigned > resShape (rank, 1 );
119119 resShape[rank - 2 ] = instShapeB[0 ];
@@ -122,8 +122,8 @@ SmallVector<unsigned> DpasEncodingAttr::getShapeB() const {
122122}
123123
124124SmallVector<unsigned > DpasEncodingAttr::getShapeC () const {
125- auto instShapeC = getDPASInstShapeC ();
126- auto repCluster = getRepCluster ();
125+ SmallVector< unsigned > instShapeC = getDPASInstShapeC ();
126+ ArrayRef< unsigned > repCluster = getRepCluster ();
127127 size_t rank = repCluster.size ();
128128 SmallVector<unsigned > resShape (rank, 1 );
129129 resShape[rank - 2 ] = instShapeC[0 ] * repCluster[rank - 2 ];
@@ -135,7 +135,7 @@ SmallVector<unsigned> DpasEncodingAttr::getSizePerThread() const {
135135 size_t rank = getWarpsPerCTA ().size ();
136136 SmallVector<unsigned > res (rank, 1 );
137137 unsigned threadsPerWarp = getSubGroupSize ();
138- auto shapeC = getDPASInstShapeC ();
138+ SmallVector< unsigned > shapeC = getDPASInstShapeC ();
139139 unsigned elemsNum = product<unsigned >(shapeC);
140140 unsigned elemsPerThread = elemsNum / threadsPerWarp;
141141 auto repCluster = getRepCluster ();
@@ -151,9 +151,10 @@ SmallVector<unsigned> DpasEncodingAttr::getRepOrder() const {
151151 llvm::report_fatal_error (" NYI. DpasEncodingAttr::getRepOrder" );
152152}
153153
154- SmallVector<unsigned > DpasEncodingAttr::getRepOrderForOperand (int opIdx) const {
155- auto rank = getWarpsPerCTA ().size ();
156- return getOrderForDotOperand (opIdx, rank, /* kMajor*/ true );
154+ SmallVector<unsigned >
155+ DpasEncodingAttr::getRepOrderForOperand (OpIdx opIdx) const {
156+ size_t rank = getWarpsPerCTA ().size ();
157+ return getOrderForDotOperand (unsigned (opIdx), rank, /* kMajor*/ true );
157158}
158159
159160SmallVector<unsigned >
@@ -162,8 +163,7 @@ DpasEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape, Type eltTy) const {
162163 assert ((rank == 2 || rank == 3 ) && " Unexpected rank of mma layout" );
163164
164165 SmallVector<unsigned > elemsPerThread (rank, 1 );
165-
166- auto shapeC = getShapeC ();
166+ SmallVector<unsigned > shapeC = getShapeC ();
167167 SmallVector<unsigned > warpsPerCTA = getWarpsPerCTA ();
168168 SmallVector<unsigned > shapePerCTATile (rank);
169169 llvm::transform (
@@ -174,7 +174,7 @@ DpasEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape, Type eltTy) const {
174174 ceil<unsigned >(shape[rank - 2 ], shapePerCTATile[rank - 2 ]);
175175 unsigned tilesCol =
176176 ceil<unsigned >(shape[rank - 1 ], shapePerCTATile[rank - 1 ]);
177- auto sizePerThread = getSizePerThread ();
177+ SmallVector< unsigned > sizePerThread = getSizePerThread ();
178178 if (rank == 3 )
179179 elemsPerThread[0 ] =
180180 sizePerThread[0 ] * ceil<unsigned >(shape[0 ], shapePerCTATile[0 ]);
@@ -208,14 +208,16 @@ SmallVector<unsigned> DpasEncodingAttr::getCTAsPerCGA() const {
208208}
209209
210210SmallVector<int64_t >
211- DpasEncodingAttr::getDPASRepetitions (ArrayRef<int64_t > shape, int opIdx) const {
211+ DpasEncodingAttr::getDPASRepetitions (ArrayRef<int64_t > shape,
212+ OpIdx opIdx) const {
212213 // Always return a 3D shape repetitions for the ease of value handling, same
213214 // to mma.
214- auto warpsPerCTA = getWarpsPerCTA ();
215- int rank = shape.size ();
215+ SmallVector< unsigned > warpsPerCTA = getWarpsPerCTA ();
216+ size_t rank = shape.size ();
216217 SmallVector<int64_t > rep (3 , 1 );
217- if (opIdx == 0 ) {
218- auto shapePerWarp = getShapeA ();
218+ switch (opIdx) {
219+ case OpIdx::OperandA: {
220+ SmallVector<unsigned > shapePerWarp = getShapeA ();
219221 int64_t numRepBatch =
220222 rank == 3 ? std::max<int64_t >(1 , shape[0 ] /
221223 (shapePerWarp[0 ] * warpsPerCTA[0 ]))
@@ -224,10 +226,9 @@ DpasEncodingAttr::getDPASRepetitions(ArrayRef<int64_t> shape, int opIdx) const {
224226 std::max<int64_t >(1 , shape[rank - 2 ] / (shapePerWarp[rank - 2 ] *
225227 warpsPerCTA[rank - 2 ])),
226228 std::max<int64_t >(1 , shape[rank - 1 ] / shapePerWarp[rank - 1 ])};
227- }
228-
229- if (opIdx == 1 ) {
230- auto shapePerWarp = getShapeB ();
229+ } break ;
230+ case OpIdx::OperandB: {
231+ SmallVector<unsigned > shapePerWarp = getShapeB ();
231232 int64_t numRepBatch =
232233 rank == 3 ? std::max<int64_t >(1 , shape[0 ] /
233234 (shapePerWarp[0 ] * warpsPerCTA[0 ]))
@@ -236,9 +237,9 @@ DpasEncodingAttr::getDPASRepetitions(ArrayRef<int64_t> shape, int opIdx) const {
236237 std::max<int64_t >(1 , shape[rank - 2 ] / shapePerWarp[rank - 2 ]),
237238 std::max<int64_t >(1 , shape[rank - 1 ] / (shapePerWarp[rank - 1 ] *
238239 warpsPerCTA[rank - 1 ]))};
240+ } break ;
239241 }
240242
241- assert (opIdx == 2 && " Unexpected operand id (valid ids are 0, 1 or 2)" );
242243 auto shapePerWarp = getShapeC ();
243244 int64_t numRepBatch =
244245 rank == 3
@@ -252,24 +253,27 @@ DpasEncodingAttr::getDPASRepetitions(ArrayRef<int64_t> shape, int opIdx) const {
252253}
253254
254255unsigned DpasEncodingAttr::getTotalElemsPerThreadForOperand (
255- ArrayRef<int64_t > shape, mlir::Type eltTy, int kWidth , int opIdx) const {
256- auto shapePerCTA = getShapePerCTA (*this , shape);
257- auto rep = getDPASRepetitions (shapePerCTA, opIdx);
258- auto threadsPerWar = getSubGroupSize ();
256+ ArrayRef<int64_t > shape, mlir::Type eltTy, int kWidth , OpIdx opIdx) const {
257+ SmallVector< int64_t > shapePerCTA = getShapePerCTA (*this , shape);
258+ SmallVector< int64_t > rep = getDPASRepetitions (shapePerCTA, opIdx);
259+ unsigned threadsPerWar = getSubGroupSize ();
259260 size_t rank = shape.size ();
260- if (opIdx == 0 ) {
261- auto shapeA = getShapeA ();
261+
262+ switch (opIdx) {
263+ case OpIdx::OperandA: {
264+ SmallVector<unsigned > shapeA = getShapeA ();
262265 auto totalElem = product<unsigned >(shapeA);
263266 // dpas operands scalar are evenly sharded to each work item.
264267 return (totalElem / threadsPerWar) * product<int64_t >(rep);
265- }
266- if (opIdx == 1 ) {
267- auto shapeB = getShapeB ();
268+ } break ;
269+ case OpIdx::OperandB: {
270+ SmallVector< unsigned > shapeB = getShapeB ();
268271 auto totalElem = product<unsigned >(shapeB);
269272 // dpas operands scalar are evenly sharded to each work item.
270273 return (totalElem / threadsPerWar) * product<int64_t >(rep);
274+ } break ;
271275 }
272- llvm_unreachable (" DpasEncodingAttr opIdx must be 0 or 1 " );
276+ llvm_unreachable (" unexpected opIdx" );
273277}
274278
275279SmallVector<unsigned > DpasEncodingAttr::getWarpOrder () const {
@@ -290,8 +294,8 @@ SmallVector<unsigned> DpasEncodingAttr::getWarpsPerCTA() const {
290294SmallVector<unsigned > DpasEncodingAttr::getThreadsPerWarp () const {
291295 size_t rank = getWarpsPerCTA ().size ();
292296 SmallVector<unsigned > res (rank, 1 );
293- auto executionSize = getExecutionSize ();
294- auto subGroupSize = getSubGroupSize ();
297+ unsigned executionSize = getExecutionSize ();
298+ unsigned subGroupSize = getSubGroupSize ();
295299 if (subGroupSize < executionSize) {
296300 llvm::report_fatal_error (" DpasEncodingAttr sub-group size could not be "
297301 " smaller than the execution size" );
@@ -302,11 +306,13 @@ SmallVector<unsigned> DpasEncodingAttr::getThreadsPerWarp() const {
302306}
303307
304308SmallVector<unsigned >
305- DpasEncodingAttr::getSizePerThreadForOperand (int kWidth , unsigned opIdx) const {
309+ DpasEncodingAttr::getSizePerThreadForOperand (int kWidth , OpIdx opIdx) const {
306310 ArrayRef<unsigned > repCluster = getRepCluster ();
307311 size_t rank = repCluster.size ();
308312 assert ((rank == 2 || rank == 3 ) && " unexpected rank number for Dpas layout" );
309- if (opIdx == 0 ) {
313+
314+ switch (opIdx) {
315+ case OpIdx::OperandA: {
310316 SmallVector<unsigned > shapeA = getDPASInstShapeA ();
311317 unsigned subGroupSize = getSubGroupSize ();
312318 unsigned opsPerChannel = getOpsPerChannel ();
@@ -323,12 +329,11 @@ DpasEncodingAttr::getSizePerThreadForOperand(int kWidth, unsigned opIdx) const {
323329 }
324330 unsigned rowsPerWarp = mlir::ceil<unsigned >(subGroupSize, packedColNum);
325331 return {shapeA[0 ] / rowsPerWarp * repCluster[rank - 2 ], packedOpsPerLane};
326- }
327-
328- if (opIdx == 1 ) {
329- auto shapeB = getShapeB ();
330- auto subGroupSize = getSubGroupSize ();
331- auto executionSize = getExecutionSize ();
332+ } break ;
333+ case OpIdx::OperandB: {
334+ SmallVector<unsigned > shapeB = getShapeB ();
335+ unsigned subGroupSize = getSubGroupSize ();
336+ unsigned executionSize = getExecutionSize ();
332337 if (subGroupSize < executionSize) {
333338 llvm::report_fatal_error (" DpasEncodingAttr sub-group size could not "
334339 " be smaller than the execution size" );
@@ -337,13 +342,14 @@ DpasEncodingAttr::getSizePerThreadForOperand(int kWidth, unsigned opIdx) const {
337342 executionSize};
338343 return {shapeB[rank - 2 ] / threadsPerWarp[0 ],
339344 shapeB[rank - 1 ] / threadsPerWarp[1 ] * repCluster[rank - 1 ]};
345+ } break ;
340346 }
341-
342- llvm::report_fatal_error (" DotOperandEncodingAttr opIdx must be 0 or 1" );
347+ llvm_unreachable (" unexpected opIdx" );
343348}
344349
345- SmallVector<unsigned > DpasEncodingAttr::getElemsPerThreadForOperands (
346- ArrayRef<int64_t > shape, Type eltTy, unsigned opIdx) const {
350+ SmallVector<unsigned >
351+ DpasEncodingAttr::getElemsPerThreadForOperands (ArrayRef<int64_t > shape,
352+ Type eltTy, OpIdx opIdx) const {
347353 SmallVector<unsigned > sizePerThread = getSizePerThreadForOperand (0 , opIdx);
348354 SmallVector<int64_t > repetitions = getDPASRepetitions (shape, opIdx);
349355
@@ -363,15 +369,15 @@ SmallVector<unsigned> DpasEncodingAttr::getContigPerThread() const {
363369 SmallVector<unsigned > contigPerThread (rank, 1 );
364370
365371 unsigned threadsPerWarp = getSubGroupSize ();
366- auto instShapeC = getDPASInstShapeC ();
367- // The software vectorization vectorized the value as C array: int a[N] -> int
368- // a[N][threadsPerWarp]
372+ SmallVector< unsigned > instShapeC = getDPASInstShapeC ();
373+ // The software vectorization vectorized the value as C array: int a[N] ->
374+ // int a[N][threadsPerWarp]
369375 if (threadsPerWarp > instShapeC[1 ]) {
370376 return contigPerThread;
371377 }
372378
373379 if (threadsPerWarp == instShapeC[1 ]) {
374- auto repCluster = getRepCluster ();
380+ ArrayRef< unsigned > repCluster = getRepCluster ();
375381 contigPerThread[rank - 2 ] = instShapeC[0 ] * repCluster[rank - 2 ];
376382 return contigPerThread;
377383 }
@@ -485,14 +491,14 @@ Attribute DpasEncodingAttr::parse(AsmParser &parser, Type type) {
485491}
486492
487493void DpasEncodingAttr::print (AsmPrinter &printer) const {
488- auto shapeA = getShapeA ();
489- llvm:: ArrayRef<unsigned > rA = shapeA;
490- auto shapeB = getShapeB ();
491- llvm:: ArrayRef<unsigned > rB = shapeB;
492- auto shapeC = getShapeC ();
493- llvm:: ArrayRef<unsigned > rC = shapeC;
494- auto warpsPerCTA = getWarpsPerCTA ();
495- auto repCluster = getRepCluster ();
494+ SmallVector< unsigned > shapeA = getShapeA ();
495+ ArrayRef<unsigned > rA = shapeA;
496+ SmallVector< unsigned > shapeB = getShapeB ();
497+ ArrayRef<unsigned > rB = shapeB;
498+ SmallVector< unsigned > shapeC = getShapeC ();
499+ ArrayRef<unsigned > rC = shapeC;
500+ SmallVector< unsigned > warpsPerCTA = getWarpsPerCTA ();
501+ ArrayRef< unsigned > repCluster = getRepCluster ();
496502 printer << " <{" << " repeatCount = " << getRepeatCount () << " , "
497503 << " systolicDepth = " << getSystolicDepth () << " , "
498504 << " executionSize = " << getExecutionSize () << " , "
@@ -515,8 +521,8 @@ DpasEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
515521SmallVector<unsigned >
516522WarpEncodingAttr::getElemsPerThread (ArrayRef<int64_t > shape, Type eltTy) const {
517523 size_t rank = shape.size ();
518- auto sizePerThread = getSizePerThread ();
519- auto threadsPerWarp = getThreadsPerWarp ();
524+ ArrayRef< unsigned > sizePerThread = getSizePerThread ();
525+ ArrayRef< unsigned > threadsPerWarp = getThreadsPerWarp ();
520526 assert (rank == sizePerThread.size () &&
521527 " unexpected rank in WarpEncodingAttr::getElemsPerThread" );
522528 SmallVector<unsigned > elemsPerThread (rank);
@@ -571,12 +577,11 @@ Attribute WarpEncodingAttr::parse(AsmParser &parser, Type type) {
571577}
572578
573579void WarpEncodingAttr::print (mlir::AsmPrinter &printer) const {
574- auto threadsPerWarp = getThreadsPerWarp ();
575- auto sizePerThread = getSizePerThread ();
576- printer << " <{" << " sizePerThread = ["
577- << llvm::ArrayRef<unsigned >(sizePerThread) << " ]"
578- << " , threadsPerWarp = [" << llvm::ArrayRef<unsigned >(threadsPerWarp)
579- << " ]" << " , order = [" << getOrder () << " ]" << " }>" ;
580+ ArrayRef<unsigned > threadsPerWarp = getThreadsPerWarp ();
581+ ArrayRef<unsigned > sizePerThread = getSizePerThread ();
582+ printer << " <{" << " sizePerThread = [" << sizePerThread << " ]"
583+ << " , threadsPerWarp = [" << threadsPerWarp << " ]" << " , order = ["
584+ << getOrder () << " ]" << " }>" ;
580585}
581586
582587// ===----------------------------------------------------------------------===//
@@ -676,7 +681,6 @@ struct TritonIntelGPUInferLayoutInterface
676681// ===----------------------------------------------------------------------===//
677682
678683void TritonIntelGPUDialect::initialize () {
679-
680684 addAttributes<
681685#define GET_ATTRDEF_LIST
682686#include " intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.cpp.inc"
0 commit comments