@@ -153,9 +153,9 @@ DpasEncodingAttr::getShapePerCTATile(ArrayRef<int64_t> tensorShape) const {
153153 SmallVector<unsigned > warpsPerCTA = getWarpsPerCTA ();
154154 size_t rank = shapeC.size ();
155155 SmallVector<unsigned > shapePerCTATile (rank);
156- for ( size_t i = 0 ; i < rank; ++i) {
157- shapePerCTATile[i] = shapeC[i] * warpsPerCTA[i];
158- }
156+ llvm::transform (
157+ llvm::zip_equal ( shapeC, warpsPerCTA), shapePerCTATile. begin (),
158+ []( auto entry) { return std::get< 0 >(entry) * std::get< 1 >(entry); });
159159 return shapePerCTATile;
160160}
161161
@@ -220,7 +220,9 @@ DpasEncodingAttr::getDPASRepetitions(ArrayRef<int64_t> shape, int opIdx) const {
220220 std::max<int64_t >(1 , shape[rank - 2 ] / (shapePerWarp[rank - 2 ] *
221221 warpsPerCTA[rank - 2 ])),
222222 std::max<int64_t >(1 , shape[rank - 1 ] / shapePerWarp[rank - 1 ])};
223- } else if (opIdx == 1 ) {
223+ }
224+
225+ if (opIdx == 1 ) {
224226 auto shapePerWarp = getShapeB ();
225227 int64_t numRepBatch =
226228 rank == 3 ? std::max<int64_t >(1 , shape[0 ] /
@@ -230,28 +232,27 @@ DpasEncodingAttr::getDPASRepetitions(ArrayRef<int64_t> shape, int opIdx) const {
230232 std::max<int64_t >(1 , shape[rank - 2 ] / shapePerWarp[rank - 2 ]),
231233 std::max<int64_t >(1 , shape[rank - 1 ] / (shapePerWarp[rank - 1 ] *
232234 warpsPerCTA[rank - 1 ]))};
233- } else {
234- assert (opIdx == 2 && " Unexpected operand id (valid ids are 0, 1 or 2)" );
235- auto shapePerWarp = getShapeC ();
236- int64_t numRepBatch =
237- rank == 3 ? std::max<int64_t >(1 , shape[0 ] /
238- (shapePerWarp[0 ] * warpsPerCTA[0 ]))
239- : 1 ;
240- return {numRepBatch,
241- std::max<int64_t >(1 , shape[rank - 2 ] / (shapePerWarp[rank - 2 ] *
242- warpsPerCTA[rank - 2 ])),
243- std::max<int64_t >(1 , shape[rank - 1 ] / (shapePerWarp[rank - 1 ] *
244- warpsPerCTA[rank - 1 ]))};
245235 }
246- return rep;
236+
237+ assert (opIdx == 2 && " Unexpected operand id (valid ids are 0, 1 or 2)" );
238+ auto shapePerWarp = getShapeC ();
239+ int64_t numRepBatch =
240+ rank == 3
241+ ? std::max<int64_t >(1 , shape[0 ] / (shapePerWarp[0 ] * warpsPerCTA[0 ]))
242+ : 1 ;
243+ return {numRepBatch,
244+ std::max<int64_t >(1 , shape[rank - 2 ] / (shapePerWarp[rank - 2 ] *
245+ warpsPerCTA[rank - 2 ])),
246+ std::max<int64_t >(1 , shape[rank - 1 ] / (shapePerWarp[rank - 1 ] *
247+ warpsPerCTA[rank - 1 ]))};
247248}
248249
249250unsigned DpasEncodingAttr::getTotalElemsPerThreadForOperands (
250251 ArrayRef<int64_t > shape, mlir::Type eltTy, int kWidth , int opIdx) const {
251252 auto shapePerCTA = getShapePerCTA (*this , shape);
252253 auto rep = getDPASRepetitions (shapePerCTA, opIdx);
253254 auto threadsPerWar = getSubGroupSize ();
254- int rank = shape.size ();
255+ size_t rank = shape.size ();
255256 if (opIdx == 0 ) {
256257 auto shapeA = getShapeA ();
257258 auto totalElem = product<unsigned >(shapeA);
@@ -269,16 +270,12 @@ unsigned DpasEncodingAttr::getTotalElemsPerThreadForOperands(
269270
270271SmallVector<unsigned > DpasEncodingAttr::getWarpOrder () const {
271272 size_t rank = getWarpsPerCTA ().size ();
272- SmallVector<unsigned > order (rank);
273- std::iota (order.rbegin (), order.rend (), 0 );
274- return order;
273+ return llvm::to_vector (llvm::reverse (llvm::seq<unsigned >(rank)));
275274}
276275
277276SmallVector<unsigned > DpasEncodingAttr::getThreadOrder () const {
278277 size_t rank = getWarpsPerCTA ().size ();
279- SmallVector<unsigned > order (rank);
280- std::iota (order.rbegin (), order.rend (), 0 );
281- return order;
278+ return llvm::to_vector (llvm::reverse (llvm::seq<unsigned >(rank)));
282279}
283280
284281SmallVector<unsigned > DpasEncodingAttr::getWarpsPerCTA () const {
@@ -305,29 +302,33 @@ DpasEncodingAttr::getShapePerCTATileForDotOperands(ArrayRef<int64_t> shape,
305302 int opIdx) const {
306303 auto parentShapePerCTATile = getShapePerCTATile (shape);
307304 size_t rank = parentShapePerCTATile.size ();
305+ assert ((rank == 2 || rank == 3 ) && " unexpected rank number for Dpas layout" );
308306 if (opIdx == 0 ) {
309307 auto shapeA = getShapeA ();
310- if (rank == 2 )
311- return {parentShapePerCTATile[0 ], shapeA[1 ]};
312- else
313- return {parentShapePerCTATile[0 ], parentShapePerCTATile[rank - 2 ],
314- shapeA[rank - 1 ]};
315- } else if (opIdx == 1 ) {
308+ return (rank == 2 )
309+ ? SmallVector<unsigned >{parentShapePerCTATile[0 ], shapeA[1 ]}
310+ : SmallVector<unsigned >{parentShapePerCTATile[0 ],
311+ parentShapePerCTATile[rank - 2 ],
312+ shapeA[rank - 1 ]};
313+ }
314+
315+ if (opIdx == 1 ) {
316316 auto shapeB = getShapeB ();
317- if (rank == 2 )
318- return {shapeB[0 ], parentShapePerCTATile[1 ]};
319- else
320- return {parentShapePerCTATile[0 ], shapeB[rank - 2 ],
321- parentShapePerCTATile[rank - 1 ]};
322- } else {
323- llvm::report_fatal_error (" DotOperandEncodingAttr opIdx must be 0 or 1" );
317+ return (rank == 2 )
318+ ? SmallVector<unsigned >{shapeB[0 ], parentShapePerCTATile[1 ]}
319+ : SmallVector<unsigned >{parentShapePerCTATile[0 ],
320+ shapeB[rank - 2 ],
321+ parentShapePerCTATile[rank - 1 ]};
324322 }
323+
324+ llvm::report_fatal_error (" DotOperandEncodingAttr opIdx must be 0 or 1" );
325325}
326326
327327SmallVector<unsigned >
328328DpasEncodingAttr::getSizePerThreadForOperands (unsigned opIdx) const {
329329 ArrayRef<unsigned > repCluster = getRepCluster ();
330330 size_t rank = repCluster.size ();
331+ assert ((rank == 2 || rank == 3 ) && " unexpected rank number for Dpas layout" );
331332 if (opIdx == 0 ) {
332333 SmallVector<unsigned > shapeA = getDPASInstShapeA ();
333334 unsigned subGroupSize = getSubGroupSize ();
@@ -345,7 +346,9 @@ DpasEncodingAttr::getSizePerThreadForOperands(unsigned opIdx) const {
345346 }
346347 unsigned rowsPerWarp = mlir::ceil<unsigned >(subGroupSize, packedColNum);
347348 return {shapeA[0 ] / rowsPerWarp * repCluster[rank - 2 ], packedOpsPerLane};
348- } else if (opIdx == 1 ) {
349+ }
350+
351+ if (opIdx == 1 ) {
349352 auto shapeB = getShapeB ();
350353 auto subGroupSize = getSubGroupSize ();
351354 auto executionSize = getExecutionSize ();
@@ -357,10 +360,9 @@ DpasEncodingAttr::getSizePerThreadForOperands(unsigned opIdx) const {
357360 executionSize};
358361 return {shapeB[rank - 2 ] / threadsPerWarp[0 ],
359362 shapeB[rank - 1 ] / threadsPerWarp[1 ] * repCluster[rank - 1 ]};
360- } else {
361- llvm::report_fatal_error (" DotOperandEncodingAttr opIdx must be 0 or 1" );
362- return {};
363363 }
364+
365+ llvm::report_fatal_error (" DotOperandEncodingAttr opIdx must be 0 or 1" );
364366}
365367
366368SmallVector<unsigned > DpasEncodingAttr::getElemsPerThreadForOperands (
@@ -389,15 +391,17 @@ SmallVector<unsigned> DpasEncodingAttr::getContigPerThread() {
389391 // a[N][threadsPerWarp]
390392 if (threadsPerWarp > instShapeC[1 ]) {
391393 return contigPerThread;
392- } else if (threadsPerWarp == instShapeC[1 ]) {
394+ }
395+
396+ if (threadsPerWarp == instShapeC[1 ]) {
393397 auto repCluster = getRepCluster ();
394398 contigPerThread[rank - 2 ] = instShapeC[0 ] * repCluster[rank - 2 ];
395399 return contigPerThread;
396- } else {
397- // threadsPerWarp < shapeC[1]
398- llvm::report_fatal_error (" DpasEncodingAttr sub-group size could not "
399- " be smaller than the threads required per row." );
400400 }
401+
402+ // threadsPerWarp < shapeC[1]
403+ llvm::report_fatal_error (" DpasEncodingAttr sub-group size could not "
404+ " be smaller than the threads required per row." );
401405}
402406
403407LogicalResult DpasEncodingAttr::verify (
0 commit comments