@@ -247,7 +247,7 @@ DpasEncodingAttr::getDPASRepetitions(ArrayRef<int64_t> shape, int opIdx) const {
247247 warpsPerCTA[rank - 1 ]))};
248248}
249249
250- unsigned DpasEncodingAttr::getTotalElemsPerThreadForOperands (
250+ unsigned DpasEncodingAttr::getTotalElemsPerThreadForOperand (
251251 ArrayRef<int64_t > shape, mlir::Type eltTy, int kWidth , int opIdx) const {
252252 auto shapePerCTA = getShapePerCTA (*this , shape);
253253 auto rep = getDPASRepetitions (shapePerCTA, opIdx);
@@ -298,8 +298,8 @@ SmallVector<unsigned> DpasEncodingAttr::getThreadsPerWarp() const {
298298}
299299
300300SmallVector<unsigned >
301- DpasEncodingAttr::getShapePerCTATileForDotOperands (ArrayRef<int64_t > shape,
302- int opIdx) const {
301+ DpasEncodingAttr::getShapePerCTATileForOperand (ArrayRef<int64_t > shape,
302+ int kWidth , int opIdx) const {
303303 auto parentShapePerCTATile = getShapePerCTATile (shape);
304304 size_t rank = parentShapePerCTATile.size ();
305305 assert ((rank == 2 || rank == 3 ) && " unexpected rank number for Dpas layout" );
@@ -325,7 +325,7 @@ DpasEncodingAttr::getShapePerCTATileForDotOperands(ArrayRef<int64_t> shape,
325325}
326326
327327SmallVector<unsigned >
328- DpasEncodingAttr::getSizePerThreadForOperands ( unsigned opIdx) const {
328+ DpasEncodingAttr::getSizePerThreadForOperand ( int kWidth , unsigned opIdx) const {
329329 ArrayRef<unsigned > repCluster = getRepCluster ();
330330 size_t rank = repCluster.size ();
331331 assert ((rank == 2 || rank == 3 ) && " unexpected rank number for Dpas layout" );
@@ -367,7 +367,7 @@ DpasEncodingAttr::getSizePerThreadForOperands(unsigned opIdx) const {
367367
368368SmallVector<unsigned > DpasEncodingAttr::getElemsPerThreadForOperands (
369369 ArrayRef<int64_t > shape, Type eltTy, unsigned opIdx) const {
370- SmallVector<unsigned > sizePerThread = getSizePerThreadForOperands ( opIdx);
370+ SmallVector<unsigned > sizePerThread = getSizePerThreadForOperand ( 0 , opIdx);
371371 SmallVector<int64_t > repetitions = getDPASRepetitions (shape, opIdx);
372372
373373 size_t rank = shape.size ();
0 commit comments