Skip to content

Commit 4630ec2

Browse files
Temporary hack to obtain correct ThreadsPerWarp for DPAS layout
Signed-off-by: Whitney Tsang <[email protected]>
1 parent 9706b7c commit 4630ec2

File tree

1 file changed

+9
-1
lines changed

1 file changed

+9
-1
lines changed

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3002,7 +3002,15 @@ struct TritonGPUVerifyTensorLayoutInterface
30023002
// Number of threads per warp.
30033003
auto kLane = StringAttr::get(module.getContext(), "lane");
30043004
int moduleThreadsPerWarp = TritonGPUDialect::getThreadsPerWarp(module);
3005-
if (ll.getInDimSize(kLane) != moduleThreadsPerWarp) {
3005+
// FIXME: ll.getInDimSize(kLane) does not return the correct threads per
3006+
// warp. https://github.com/intel/intel-xpu-backend-for-triton/issues/4861
3007+
unsigned layoutThreadsPerWarp = ll.getInDimSize(kLane);
3008+
if (auto dotOperandLayout =
3009+
dyn_cast<DotOperandEncodingAttr>(rankedTy.getEncoding()))
3010+
if (auto dpasLayout =
3011+
dyn_cast<intel::DpasEncodingAttr>(dotOperandLayout.getParent()))
3012+
layoutThreadsPerWarp = dpasLayout.getThreadsPerWarp();
3013+
if (layoutThreadsPerWarp != moduleThreadsPerWarp) {
30063014
return makeErr() << layout << ".\nLayout has " << ll.getInDimSize(kLane)
30073015
<< " threads per warp, but the module specifies "
30083016
<< moduleThreadsPerWarp << " threads per warp.";

0 commit comments

Comments
 (0)