@@ -153,9 +153,9 @@ DpasEncodingAttr::getShapePerCTATile(ArrayRef<int64_t> tensorShape) const {
153153 SmallVector<unsigned > warpsPerCTA = getWarpsPerCTA ();
154154 size_t rank = shapeC.size ();
155155 SmallVector<unsigned > shapePerCTATile (rank);
156- for ( size_t i = 0 ; i < rank; ++i) {
157- shapePerCTATile[i] = shapeC[i] * warpsPerCTA[i];
158- }
156+ llvm::transform (
157+ llvm::zip_equal ( shapeC, warpsPerCTA), shapePerCTATile. begin (),
158+ []( auto entry) { return std::get< 0 >(entry) * std::get< 1 >(entry); });
159159 return shapePerCTATile;
160160}
161161
@@ -220,7 +220,9 @@ DpasEncodingAttr::getDPASRepetitions(ArrayRef<int64_t> shape, int opIdx) const {
220220 std::max<int64_t >(1 , shape[rank - 2 ] / (shapePerWarp[rank - 2 ] *
221221 warpsPerCTA[rank - 2 ])),
222222 std::max<int64_t >(1 , shape[rank - 1 ] / shapePerWarp[rank - 1 ])};
223- } else if (opIdx == 1 ) {
223+ }
224+
225+ if (opIdx == 1 ) {
224226 auto shapePerWarp = getShapeB ();
225227 int64_t numRepBatch =
226228 rank == 3 ? std::max<int64_t >(1 , shape[0 ] /
@@ -230,28 +232,27 @@ DpasEncodingAttr::getDPASRepetitions(ArrayRef<int64_t> shape, int opIdx) const {
230232 std::max<int64_t >(1 , shape[rank - 2 ] / shapePerWarp[rank - 2 ]),
231233 std::max<int64_t >(1 , shape[rank - 1 ] / (shapePerWarp[rank - 1 ] *
232234 warpsPerCTA[rank - 1 ]))};
233- } else {
234- assert (opIdx == 2 && " Unexpected operand id (valid ids are 0, 1 or 2)" );
235- auto shapePerWarp = getShapeC ();
236- int64_t numRepBatch =
237- rank == 3 ? std::max<int64_t >(1 , shape[0 ] /
238- (shapePerWarp[0 ] * warpsPerCTA[0 ]))
239- : 1 ;
240- return {numRepBatch,
241- std::max<int64_t >(1 , shape[rank - 2 ] / (shapePerWarp[rank - 2 ] *
242- warpsPerCTA[rank - 2 ])),
243- std::max<int64_t >(1 , shape[rank - 1 ] / (shapePerWarp[rank - 1 ] *
244- warpsPerCTA[rank - 1 ]))};
245235 }
246- return rep;
236+
237+ assert (opIdx == 2 && " Unexpected operand id (valid ids are 0, 1 or 2)" );
238+ auto shapePerWarp = getShapeC ();
239+ int64_t numRepBatch =
240+ rank == 3
241+ ? std::max<int64_t >(1 , shape[0 ] / (shapePerWarp[0 ] * warpsPerCTA[0 ]))
242+ : 1 ;
243+ return {numRepBatch,
244+ std::max<int64_t >(1 , shape[rank - 2 ] / (shapePerWarp[rank - 2 ] *
245+ warpsPerCTA[rank - 2 ])),
246+ std::max<int64_t >(1 , shape[rank - 1 ] / (shapePerWarp[rank - 1 ] *
247+ warpsPerCTA[rank - 1 ]))};
247248}
248249
249250unsigned DpasEncodingAttr::getTotalElemsPerThreadForOperands (
250251 ArrayRef<int64_t > shape, mlir::Type eltTy, int kWidth , int opIdx) const {
251252 auto shapePerCTA = getShapePerCTA (*this , shape);
252253 auto rep = getDPASRepetitions (shapePerCTA, opIdx);
253254 auto threadsPerWar = getSubGroupSize ();
254- int rank = shape.size ();
255+ size_t rank = shape.size ();
255256 if (opIdx == 0 ) {
256257 auto shapeA = getShapeA ();
257258 auto totalElem = product<unsigned >(shapeA);
@@ -269,16 +270,12 @@ unsigned DpasEncodingAttr::getTotalElemsPerThreadForOperands(
269270
270271SmallVector<unsigned > DpasEncodingAttr::getWarpOrder () const {
271272 size_t rank = getWarpsPerCTA ().size ();
272- SmallVector<unsigned > order (rank);
273- std::iota (order.rbegin (), order.rend (), 0 );
274- return order;
273+ return llvm::to_vector (llvm::reverse (llvm::seq<unsigned >(rank)));
275274}
276275
277276SmallVector<unsigned > DpasEncodingAttr::getThreadOrder () const {
278277 size_t rank = getWarpsPerCTA ().size ();
279- SmallVector<unsigned > order (rank);
280- std::iota (order.rbegin (), order.rend (), 0 );
281- return order;
278+ return llvm::to_vector (llvm::reverse (llvm::seq<unsigned >(rank)));
282279}
283280
284281SmallVector<unsigned > DpasEncodingAttr::getWarpsPerCTA () const {
@@ -307,18 +304,18 @@ DpasEncodingAttr::getShapePerCTATileForDotOperands(ArrayRef<int64_t> shape,
307304 size_t rank = parentShapePerCTATile.size ();
308305 if (opIdx == 0 ) {
309306 auto shapeA = getShapeA ();
310- if (rank == 2 )
311- return {parentShapePerCTATile[0 ], shapeA[1 ]};
312- else
313- return {parentShapePerCTATile[ 0 ], parentShapePerCTATile[rank - 2 ],
314- shapeA[rank - 1 ]};
307+ return (rank == 2 )
308+ ? SmallVector< unsigned > {parentShapePerCTATile[0 ], shapeA[1 ]}
309+ : SmallVector< unsigned >{parentShapePerCTATile[ 0 ],
310+ parentShapePerCTATile[rank - 2 ],
311+ shapeA[rank - 1 ]};
315312 } else if (opIdx == 1 ) {
316313 auto shapeB = getShapeB ();
317- if (rank == 2 )
318- return {shapeB[0 ], parentShapePerCTATile[1 ]};
319- else
320- return {parentShapePerCTATile[ 0 ], shapeB[rank - 2 ],
321- parentShapePerCTATile[rank - 1 ]};
314+ return (rank == 2 )
315+ ? SmallVector< unsigned > {shapeB[0 ], parentShapePerCTATile[1 ]}
316+ : SmallVector< unsigned >{parentShapePerCTATile[ 0 ],
317+ shapeB[rank - 2 ],
318+ parentShapePerCTATile[rank - 1 ]};
322319 } else {
323320 llvm::report_fatal_error (" DotOperandEncodingAttr opIdx must be 0 or 1" );
324321 }
0 commit comments