11#include " triton/Dialect/Triton/IR/Dialect.h"
22
3+ #include < cstdint>
34#include < numeric>
45
56#include " intel/include/Dialect/TritonIntelGPU/IR/LinearLayoutConversions.h"
1213
1314#include " intel/include/Dialect/TritonIntelGPU/IR/Dialect.cpp.inc"
1415
16+ #include " llvm/ADT/SmallVector.h"
1517#include " llvm/ADT/TypeSwitch.h"
18+ #include " llvm/Support/ErrorHandling.h"
1619
1720using namespace mlir ;
1821using namespace mlir ::triton;
@@ -104,49 +107,75 @@ SmallVector<unsigned> DpasEncodingAttr::getDPASInstShapeC() const {
104107SmallVector<unsigned > DpasEncodingAttr::getShapeA () const {
105108 auto shapeA = getDPASInstShapeA ();
106109 auto repCluster = getRepCluster ();
107- return {shapeA[0 ] * repCluster[0 ], shapeA[1 ]};
110+ size_t rank = repCluster.size ();
111+ SmallVector<unsigned > resShape (rank, 1 );
112+ resShape[rank - 2 ] = shapeA[0 ] * repCluster[rank - 2 ];
113+ resShape[rank - 1 ] = shapeA[1 ];
114+ return resShape;
108115}
109116
110117SmallVector<unsigned > DpasEncodingAttr::getShapeB () const {
111118 auto shapeB = getDPASInstShapeB ();
112119 auto repCluster = getRepCluster ();
113- return {shapeB[0 ], shapeB[1 ] * repCluster[1 ]};
120+ size_t rank = repCluster.size ();
121+ SmallVector<unsigned > resShape (rank, 1 );
122+ resShape[rank - 2 ] = shapeB[0 ];
123+ resShape[rank - 1 ] = shapeB[1 ] * repCluster[rank - 1 ];
124+ return resShape;
114125}
115126
116127SmallVector<unsigned > DpasEncodingAttr::getShapeC () const {
117128 auto shapeC = getDPASInstShapeC ();
118129 auto repCluster = getRepCluster ();
119- return {shapeC[0 ] * repCluster[0 ], shapeC[1 ] * repCluster[1 ]};
130+ size_t rank = repCluster.size ();
131+ SmallVector<unsigned > resShape (rank, 1 );
132+ resShape[rank - 2 ] = shapeC[0 ] * repCluster[rank - 2 ];
133+ resShape[rank - 1 ] = shapeC[1 ] * repCluster[rank - 1 ];
134+ return resShape;
120135}
121136
122137SmallVector<unsigned > DpasEncodingAttr::getSizePerThread () const {
138+ size_t rank = getWarpsPerCTA ().size ();
139+ SmallVector<unsigned > res (rank, 1 );
123140 unsigned threadsPerWarp = getSubGroupSize ();
124141 auto shapeC = getDPASInstShapeC ();
125142 unsigned elemsNum = product<unsigned >(shapeC);
126143 unsigned elemsPerThread = elemsNum / threadsPerWarp;
127144 auto repCluster = getRepCluster ();
128145 // The Value is shard to lanes to threads per DPAS instruction.
129- return {elemsPerThread * repCluster[0 ], repCluster[1 ]};
146+ res[rank - 2 ] = elemsPerThread * repCluster[rank - 2 ];
147+ res[rank - 1 ] = repCluster[rank - 1 ];
148+ return res;
130149}
131150
132151SmallVector<unsigned >
133152DpasEncodingAttr::getShapePerCTATile (ArrayRef<int64_t > tensorShape) const {
134153 auto shapeC = getShapeC ();
135- return {shapeC[0 ] * getWarpsPerCTA ()[0 ], shapeC[1 ] * getWarpsPerCTA ()[1 ]};
154+ SmallVector<unsigned > warpsPerCTA = getWarpsPerCTA ();
155+ size_t rank = shapeC.size ();
156+ assert (rank == shapeC.size () &&
157+ " ShapeC and WarpsPerCTA must have the same rank" );
158+ SmallVector<unsigned > shapePerCTATile (rank);
159+ for (size_t i = 0 ; i < rank; ++i) {
160+ shapePerCTATile[i] = shapeC[i] * warpsPerCTA[i];
161+ }
162+ return shapePerCTATile;
136163}
137164
138165SmallVector<unsigned >
139166DpasEncodingAttr::getElemsPerThread (ArrayRef<int64_t > shape, Type eltTy) const {
140167 size_t rank = shape.size ();
141- assert (rank == 2 && " Unexpected rank of mma layout" );
168+ assert (( rank == 2 || rank == 3 ) && " Unexpected rank of mma layout" );
142169
143170 SmallVector<unsigned > elemsPerThread (rank);
144171 auto shapePerCTATile = getShapePerCTATile (shape);
145- unsigned tilesRow = ceil<unsigned >(shape[0 ], shapePerCTATile[0 ]);
146- unsigned tilesCol = ceil<unsigned >(shape[1 ], shapePerCTATile[1 ]);
172+ unsigned tilesRow =
173+ ceil<unsigned >(shape[rank - 2 ], shapePerCTATile[rank - 2 ]);
174+ unsigned tilesCol =
175+ ceil<unsigned >(shape[rank - 1 ], shapePerCTATile[rank - 1 ]);
147176 auto sizePerThread = getSizePerThread ();
148- elemsPerThread[0 ] = sizePerThread[0 ] * tilesRow;
149- elemsPerThread[1 ] = sizePerThread[1 ] * tilesCol;
177+ elemsPerThread[rank - 2 ] = sizePerThread[rank - 2 ] * tilesRow;
178+ elemsPerThread[rank - 1 ] = sizePerThread[rank - 1 ] * tilesCol;
150179
151180 return elemsPerThread;
152181}
@@ -174,48 +203,73 @@ SmallVector<unsigned> DpasEncodingAttr::getCTAsPerCGA() const {
174203SmallVector<int64_t >
175204DpasEncodingAttr::getDPASRepetitions (ArrayRef<int64_t > shape, int opIdx) const {
176205 auto warpsPerCTA = getWarpsPerCTA ();
206+ int rank = shape.size ();
207+ SmallVector<int64_t > res (rank);
177208 if (opIdx == 0 ) {
178209 auto shapePerWarp = getShapeA ();
179- return {std::max<int64_t >(1 , shape[0 ] / (shapePerWarp[0 ] * warpsPerCTA[0 ])),
180- std::max<int64_t >(1 , shape[1 ] / shapePerWarp[1 ])};
210+ if (rank == 3 )
211+ res[0 ] =
212+ std::max<int64_t >(1 , shape[0 ] / (shapePerWarp[0 ] * warpsPerCTA[0 ]));
213+ res[rank - 2 ] = std::max<int64_t >(
214+ 1 , shape[rank - 2 ] / (shapePerWarp[rank - 2 ] * warpsPerCTA[rank - 2 ]));
215+ res[rank - 1 ] =
216+ std::max<int64_t >(1 , shape[rank - 1 ] / shapePerWarp[rank - 1 ]);
181217 } else if (opIdx == 1 ) {
182218 auto shapePerWarp = getShapeB ();
183- return {
184- std::max<int64_t >(1 , shape[0 ] / shapePerWarp[0 ]),
185- std::max<int64_t >(1 , shape[1 ] / (shapePerWarp[1 ] * warpsPerCTA[1 ]))};
219+ if (rank == 3 )
220+ res[0 ] =
221+ std::max<int64_t >(1 , shape[0 ] / (shapePerWarp[0 ] * warpsPerCTA[0 ]));
222+ res[rank - 2 ] =
223+ std::max<int64_t >(1 , shape[rank - 2 ] / shapePerWarp[rank - 2 ]);
224+ res[rank - 1 ] = std::max<int64_t >(
225+ 1 , shape[rank - 1 ] / (shapePerWarp[rank - 1 ] * warpsPerCTA[rank - 1 ]));
186226 } else {
187227 assert (opIdx == 2 && " Unexpected operand id (valid ids are 0, 1 or 2)" );
188228 auto shapePerWarp = getShapeC ();
189- return {
190- std::max<int64_t >(1 , mlir::ceil<unsigned >(
191- shape[0 ], shapePerWarp[0 ] * warpsPerCTA[0 ])),
192- std::max<int64_t >(1 , mlir::ceil<unsigned >(
193- shape[1 ], shapePerWarp[1 ] * warpsPerCTA[1 ]))};
229+ if (rank == 3 )
230+ res[0 ] =
231+ std::max<int64_t >(1 , shape[0 ] / (shapePerWarp[0 ] * warpsPerCTA[0 ]));
232+ res[rank - 2 ] = std::max<int64_t >(
233+ 1 , shape[rank - 2 ] / (shapePerWarp[rank - 2 ] * warpsPerCTA[rank - 2 ]));
234+ res[rank - 1 ] = std::max<int64_t >(
235+ 1 , shape[rank - 1 ] / (shapePerWarp[rank - 1 ] * warpsPerCTA[rank - 1 ]));
194236 }
237+ return res;
195238}
196239
197240unsigned DpasEncodingAttr::getTotalElemsPerThreadForOperands (
198241 ArrayRef<int64_t > shape, mlir::Type eltTy, int kWidth , int opIdx) const {
199242 auto shapePerCTA = getShapePerCTA (*this , shape);
200243 auto rep = getDPASRepetitions (shapePerCTA, opIdx);
201244 auto threadsPerWar = getSubGroupSize ();
245+ int rank = shape.size ();
202246 if (opIdx == 0 ) {
203247 auto shapeA = getShapeA ();
204248 auto totalElem = product<unsigned >(shapeA);
205249 // dpas operands scalar are evenly sharded to each work item.
206- return (totalElem / threadsPerWar) * rep[0 ] * rep[1 ];
207- } else { // if (opIdx == 1)
250+ return (totalElem / threadsPerWar) * product<int64_t >(rep);
251+ }
252+ if (opIdx == 1 ) {
208253 auto shapeB = getShapeB ();
209254 auto totalElem = product<unsigned >(shapeB);
210255 // dpas operands scalar are evenly sharded to each work item.
211- return (totalElem / threadsPerWar) * rep[ 0 ] * rep[ 1 ] ;
256+ return (totalElem / threadsPerWar) * product< int64_t >( rep) ;
212257 }
258+ llvm_unreachable (" DpasEncodingAttr opIdx must be 0 or 1" );
213259}
214260
215- SmallVector<unsigned > DpasEncodingAttr::getWarpOrder () const { return {1 , 0 }; }
261+ SmallVector<unsigned > DpasEncodingAttr::getWarpOrder () const {
262+ size_t rank = getWarpsPerCTA ().size ();
263+ SmallVector<unsigned > order (rank);
264+ std::iota (order.rbegin (), order.rend (), 0 );
265+ return order;
266+ }
216267
217268SmallVector<unsigned > DpasEncodingAttr::getThreadOrder () const {
218- return {1 , 0 };
269+ size_t rank = getWarpsPerCTA ().size ();
270+ SmallVector<unsigned > order (rank);
271+ std::iota (order.rbegin (), order.rend (), 0 );
272+ return order;
219273}
220274
221275SmallVector<unsigned > DpasEncodingAttr::getWarpsPerCTA () const {
@@ -224,33 +278,48 @@ SmallVector<unsigned> DpasEncodingAttr::getWarpsPerCTA() const {
224278}
225279
226280SmallVector<unsigned > DpasEncodingAttr::getThreadsPerWarp () const {
281+ size_t rank = getWarpsPerCTA ().size ();
282+ SmallVector<unsigned > res (rank, 1 );
227283 auto executionSize = getExecutionSize ();
228284 auto subGroupSize = getSubGroupSize ();
229285 if (subGroupSize < executionSize) {
230286 llvm::report_fatal_error (" DpasEncodingAttr sub-group size could not be "
231287 " smaller than the execution size" );
232288 }
233- return {subGroupSize / executionSize, executionSize};
289+ res[rank - 2 ] = subGroupSize / executionSize;
290+ res[rank - 1 ] = executionSize;
291+ return res;
234292}
235293
236294SmallVector<unsigned >
237295DpasEncodingAttr::getShapePerCTATileForDotOperands (ArrayRef<int64_t > shape,
238296 int opIdx) const {
239297 auto parentShapePerCTATile = getShapePerCTATile (shape);
240- auto threadsPerWarp = getThreadsPerWarp ();
298+ // auto threadsPerWarp = getThreadsPerWarp();
299+ size_t rank = parentShapePerCTATile.size ();
241300 if (opIdx == 0 ) {
242301 auto shapeA = getShapeA ();
243- return {parentShapePerCTATile[0 ], shapeA[1 ]};
302+ if (rank == 2 )
303+ return {parentShapePerCTATile[0 ], shapeA[1 ]};
304+ else
305+ return {parentShapePerCTATile[0 ], parentShapePerCTATile[rank - 2 ],
306+ shapeA[rank - 1 ]};
244307 } else if (opIdx == 1 ) {
245308 auto shapeB = getShapeB ();
246- return {shapeB[0 ], parentShapePerCTATile[1 ]};
309+ if (rank == 2 )
310+ return {shapeB[0 ], parentShapePerCTATile[1 ]};
311+ else
312+ return {parentShapePerCTATile[0 ], shapeB[rank - 2 ],
313+ parentShapePerCTATile[rank - 1 ]};
247314 } else {
248315 llvm::report_fatal_error (" DotOperandEncodingAttr opIdx must be 0 or 1" );
249316 }
250317}
251318
252319SmallVector<unsigned >
253320DpasEncodingAttr::getSizePerThreadForOperands (unsigned opIdx) const {
321+ ArrayRef<unsigned > repCluster = getRepCluster ();
322+ size_t rank = repCluster.size ();
254323 if (opIdx == 0 ) {
255324 SmallVector<unsigned > shapeA = getDPASInstShapeA ();
256325 unsigned subGroupSize = getSubGroupSize ();
@@ -267,8 +336,7 @@ DpasEncodingAttr::getSizePerThreadForOperands(unsigned opIdx) const {
267336 " be smaller than the threads required per row." );
268337 }
269338 unsigned rowsPerWarp = mlir::ceil<unsigned >(subGroupSize, packedColNum);
270- auto repCluster = getRepCluster ();
271- return {shapeA[0 ] / rowsPerWarp * repCluster[0 ], packedOpsPerLane};
339+ return {shapeA[0 ] / rowsPerWarp * repCluster[rank - 2 ], packedOpsPerLane};
272340 } else if (opIdx == 1 ) {
273341 auto shapeB = getShapeB ();
274342 auto subGroupSize = getSubGroupSize ();
@@ -279,9 +347,8 @@ DpasEncodingAttr::getSizePerThreadForOperands(unsigned opIdx) const {
279347 }
280348 SmallVector<unsigned , 2 > threadsPerWarp = {subGroupSize / executionSize,
281349 executionSize};
282- auto repCluster = getRepCluster ();
283- return {shapeB[0 ] / threadsPerWarp[0 ],
284- shapeB[1 ] / threadsPerWarp[1 ] * repCluster[1 ]};
350+ return {shapeB[rank - 2 ] / threadsPerWarp[0 ],
351+ shapeB[rank - 1 ] / threadsPerWarp[1 ] * repCluster[rank - 1 ]};
285352 } else {
286353 llvm::report_fatal_error (" DotOperandEncodingAttr opIdx must be 0 or 1" );
287354 return {};
@@ -293,20 +360,31 @@ SmallVector<unsigned> DpasEncodingAttr::getElemsPerThreadForOperands(
293360 SmallVector<unsigned > sizePerThread = getSizePerThreadForOperands (opIdx);
294361 SmallVector<int64_t > repetitions = getDPASRepetitions (shape, opIdx);
295362
296- return {static_cast <unsigned >(sizePerThread[0 ] * repetitions[0 ]),
297- static_cast <unsigned >(sizePerThread[1 ] * repetitions[1 ])};
363+ size_t rank = shape.size ();
364+ SmallVector<unsigned > elemsPerThread (rank);
365+ if (rank == 3 )
366+ elemsPerThread[0 ] = repetitions[0 ];
367+ elemsPerThread[rank - 2 ] = sizePerThread[rank - 2 ] * repetitions[rank - 2 ];
368+ elemsPerThread[rank - 1 ] = sizePerThread[rank - 1 ] * repetitions[rank - 1 ];
369+
370+ return elemsPerThread;
298371};
299372
300373SmallVector<unsigned > DpasEncodingAttr::getContigPerThread () {
374+ size_t rank = getWarpsPerCTA ().size ();
375+ assert (rank == 2 || rank == 3 );
376+ SmallVector<unsigned > contigPerThread (rank, 1 );
377+
301378 unsigned threadsPerWarp = getSubGroupSize ();
302379 auto shapeC = getDPASInstShapeC ();
303380 // The software vectorization vectorized the value as C array: int a[N] -> int
304381 // a[N][threadsPerWarp]
305382 if (threadsPerWarp > shapeC[1 ]) {
306- return { 1 , 1 } ;
383+ return contigPerThread ;
307384 } else if (threadsPerWarp == shapeC[1 ]) {
308385 auto repCluster = getRepCluster ();
309- return {shapeC[0 ] * repCluster[0 ], 1 };
386+ contigPerThread[rank - 2 ] = shapeC[0 ] * repCluster[rank - 2 ];
387+ return contigPerThread;
310388 } else {
311389 // threadsPerWarp < shapeC[1]
312390 llvm::report_fatal_error (" DpasEncodingAttr sub-group size could not "
@@ -333,8 +411,8 @@ LogicalResult DpasEncodingAttr::verify(
333411 return emitError () << " systolicDepth must be 8, but was:" << opsPerChan;
334412 }
335413
336- if (repCluster.size () != 2 ) {
337- return emitError () << " expected rank 2 of repCluster, but the rank is:"
414+ if (!( repCluster.size () == 2 || repCluster. size () == 3 ) ) {
415+ return emitError () << " expected rank 2 or 3 of repCluster, but the rank is:"
338416 << repCluster.size ();
339417 }
340418
0 commit comments