@@ -305,8 +305,29 @@ bool hasConvertToMMATransisitiveUse(Operation *op, Attribute encoding) {
305305// Return true if the op is an op with a layout we don't want to change. We will
306306// propagate the layout starting from anchor ops.
307307bool isLayoutAnchor (Operation *op) {
308- if (isa<LoadOp, StoreOp>(op))
308+ if (isa<LoadOp>(op)) {
309+ #ifdef HACK
310+ // Note: currently block ptr loads are always considered not expensive and
311+ // therefore they are never layout anchors.
312+ Value base = op->getOperand (0 );
313+ auto parentLoop = op->getParentOfType <scf::ForOp>();
314+ bool isInLoop = parentLoop != nullptr ;
315+ bool isTensorPtrLoad = mlir::triton::isTensorPointerType (base.getType ());
316+
317+ if (!isTensorPtrLoad)
318+ ttgi::isExpensiveLoadOrStore (op);
319+
320+ // HACK: consider block ptr loads expensive if they are in a loop.
321+ return isInLoop;
322+ #else
309323 return ttgi::isExpensiveLoadOrStore (op);
324+ #endif
325+ }
326+
327+ if (isa<StoreOp>(op)) {
328+ return ttgi::isExpensiveLoadOrStore (op);
329+ }
330+
310331 if (isa<DotOp, AtomicCASOp>(op))
311332 return true ;
312333 if (isa<AtomicRMWOp>(op))
@@ -356,6 +377,17 @@ void LayoutPropagation::initAnchorLayout() {
356377 }
357378 }
358379 });
380+
381+ #if 0
382+ llvm::errs() << "Initial layouts:\n";
383+ for (auto &entry : layouts) {
384+ llvm::errs() << entry.first << "\n";
385+ for (auto &layout : entry.second.encodings) {
386+ llvm::errs() << " " << layout << "\n";
387+ }
388+ }
389+ llvm::errs() << "\n\n";
390+ #endif
359391}
360392
361393void LayoutPropagation::setEncoding (ValueRange values, LayoutInfo &info,
@@ -969,8 +1001,28 @@ Operation *LayoutPropagation::rewriteOp(Operation *op) {
9691001}
9701002
9711003bool canBeRemat (Operation *op) {
972- if (isa<LoadOp, StoreOp>(op))
1004+ if (isa<LoadOp>(op)) {
1005+ #ifdef HACK
1006+ // Note: currently block ptr loads are always considered not expensive and
1007+ // therefore rematerializable.
1008+ Value base = op->getOperand (0 );
1009+ auto parentLoop = op->getParentOfType <scf::ForOp>();
1010+ bool isInLoop = parentLoop != nullptr ;
1011+ bool isTensorPtrLoad = mlir::triton::isTensorPointerType (base.getType ());
1012+
1013+ if (!isTensorPtrLoad)
1014+ return !ttgi::isExpensiveLoadOrStore (op);
1015+
1016+ // HACK: consider block ptr loads expensive if they are in a loop.
1017+ return !isInLoop;
1018+ #else
1019+ return !ttgi::isExpensiveLoadOrStore (op);
1020+ #endif
1021+ }
1022+
1023+ if (isa<StoreOp>(op))
9731024 return !ttgi::isExpensiveLoadOrStore (op);
1025+
9741026 if (isa<AtomicRMWOp, AtomicCASOp, DotOp>(op))
9751027 return false ;
9761028 if (isa<scf::WhileOp, scf::ConditionOp>(op))
0 commit comments