@@ -194,7 +194,7 @@ DpasEncodingAttr::getDPASRepetitions(ArrayRef<int64_t> shape, int opIdx) const {
194194 }
195195}
196196
197- unsigned DpasEncodingAttr::getTotalElemsPerThreadForOperands (
197+ unsigned DpasEncodingAttr::getTotalElemsPerThreadForOperand (
198198 ArrayRef<int64_t > shape, mlir::Type eltTy, int kWidth , int opIdx) const {
199199 auto shapePerCTA = getShapePerCTA (*this , shape);
200200 auto rep = getDPASRepetitions (shapePerCTA, opIdx);
@@ -234,8 +234,8 @@ SmallVector<unsigned> DpasEncodingAttr::getThreadsPerWarp() const {
234234}
235235
236236SmallVector<unsigned >
237- DpasEncodingAttr::getShapePerCTATileForDotOperands (ArrayRef<int64_t > shape,
238- int opIdx) const {
237+ DpasEncodingAttr::getShapePerCTATileForOperand (ArrayRef<int64_t > shape,
238+ int kWidth , int opIdx) const {
239239 auto parentShapePerCTATile = getShapePerCTATile (shape);
240240 auto threadsPerWarp = getThreadsPerWarp ();
241241 if (opIdx == 0 ) {
@@ -250,7 +250,7 @@ DpasEncodingAttr::getShapePerCTATileForDotOperands(ArrayRef<int64_t> shape,
250250}
251251
252252SmallVector<unsigned >
253- DpasEncodingAttr::getSizePerThreadForOperands ( unsigned opIdx) const {
253+ DpasEncodingAttr::getSizePerThreadForOperand ( int kWidth , unsigned opIdx) const {
254254 if (opIdx == 0 ) {
255255 SmallVector<unsigned > shapeA = getDPASInstShapeA ();
256256 unsigned subGroupSize = getSubGroupSize ();
@@ -290,7 +290,7 @@ DpasEncodingAttr::getSizePerThreadForOperands(unsigned opIdx) const {
290290
291291SmallVector<unsigned > DpasEncodingAttr::getElemsPerThreadForOperands (
292292 ArrayRef<int64_t > shape, Type eltTy, unsigned opIdx) const {
293- SmallVector<unsigned > sizePerThread = getSizePerThreadForOperands ( opIdx);
293+ SmallVector<unsigned > sizePerThread = getSizePerThreadForOperand ( 0 , opIdx);
294294 SmallVector<int64_t > repetitions = getDPASRepetitions (shape, opIdx);
295295
296296 return {static_cast <unsigned >(sizePerThread[0 ] * repetitions[0 ]),
0 commit comments