Skip to content

Commit 1efa822

Browse files
[intel] Fix getTotalElemsPerThread failures
Signed-off-by: Whitney Tsang <[email protected]>
1 parent ddc8f87 commit 1efa822

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -992,6 +992,10 @@ unsigned DotOperandEncodingAttr::getTotalElemsPerThread(ArrayRef<int64_t> shape,
 992  992         return amdWmmaParent.getTotalElemsPerThreadForOperand(
 993  993             shape, eltTy, getKWidth(), getOpIdx());
 994  994       }
      995 +     if (auto dpasParent = mlir::dyn_cast<intel::DpasEncodingAttr>(mmaParent)) {
      996 +       return dpasParent.getTotalElemsPerThreadForOperand(
      997 +           shape, eltTy, getKWidth(), getOpIdx());
      998 +     }
 995  999     }
 996 1000     if (auto blockedLayout = mlir::dyn_cast<BlockedEncodingAttr>(getParent())) {
 997 1001       auto shapePerCTA = getShapePerCTA(*this, shape);

0 commit comments

Comments
 (0)