Skip to content

Commit cbc630b

Browse files
committed
MaterializeBlockPointer fix for GEMM with 1st operand transposed
Signed-off-by: Tiotto, Ettore <[email protected]>
1 parent 7c9a0f9 commit cbc630b

File tree

1 file changed

+11
-4
lines changed

1 file changed

+11
-4
lines changed

third_party/intel/lib/TritonIntelGPUTransforms/MaterializeBlockPointer.cpp

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44
#include "mlir/Dialect/Arith/IR/Arith.h"
55
#include "mlir/IR/Visitors.h"
66
#include "triton/Analysis/Utility.h"
7+
#include "llvm/Support/Casting.h"
78
#include "llvm/Support/Debug.h"
9+
#include <optional>
810

911
#define DEBUG_TYPE "tritonintelgpu-materialize-block-pointer"
1012
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
@@ -131,9 +133,12 @@ struct TritonIntelGPUMaterializeBlockPointerPass
131133
getDotLayout(tt::LoadOp loadOp) const {
132134
Value ptr = loadOp.getPtr();
133135
if (!tt::isTensorPointerType(ptr.getType()))
134-
return nullptr;
136+
return std::nullopt;
135137

136138
RankedTensorType tensorType = ttgi::getRankedTensorType(ptr.getType());
139+
if (!tensorType)
140+
return std::nullopt;
141+
137142
auto dotLayout = ttgi::getDotEncoding(tensorType);
138143
if (dotLayout)
139144
return dotLayout;
@@ -154,13 +159,15 @@ struct TritonIntelGPUMaterializeBlockPointerPass
154159
};
155160

156161
Operation::user_range users = loadOp->getUsers();
157-
if (allUsersAreConvertOps(users) && allUserHaveIdenticalLayout(users)) {
162+
if (!users.empty() && allUsersAreConvertOps(users) &&
163+
allUserHaveIdenticalLayout(users)) {
158164
Attribute firstUserLayout =
159165
cast<ttg::ConvertLayoutOp>(*users.begin()).getType().getEncoding();
160-
return dyn_cast<ttg::DotOperandEncodingAttr>(firstUserLayout);
166+
return llvm::dyn_cast_if_present<ttg::DotOperandEncodingAttr>(
167+
firstUserLayout);
161168
}
162169

163-
return nullptr;
170+
return std::nullopt;
164171
}
165172
};
166173

0 commit comments

Comments
 (0)