@@ -305,29 +305,8 @@ bool hasConvertToMMATransisitiveUse(Operation *op, Attribute encoding) {
305305// Return true if the op is an op with a layout we don't want to change. We will
306306// propagate the layout starting from anchor ops.
307307bool isLayoutAnchor (Operation *op) {
308- if (isa<LoadOp>(op)) {
309- #ifdef HACK
310- // Note: currently block ptr loads are always considered not expensive and
311- // therefore they are never layout anchors.
312- Value base = op->getOperand (0 );
313- auto parentLoop = op->getParentOfType <scf::ForOp>();
314- bool isInLoop = parentLoop != nullptr ;
315- bool isTensorPtrLoad = mlir::triton::isTensorPointerType (base.getType ());
316-
317- if (!isTensorPtrLoad)
318- ttgi::isExpensiveLoadOrStore (op);
319-
320- // HACK: consider block ptr loads expensive if they are in a loop.
321- return isInLoop;
322- #else
308+ if (isa<LoadOp, StoreOp>(op))
323309 return ttgi::isExpensiveLoadOrStore (op);
324- #endif
325- }
326-
327- if (isa<StoreOp>(op)) {
328- return ttgi::isExpensiveLoadOrStore (op);
329- }
330-
331310 if (isa<DotOp, AtomicCASOp>(op))
332311 return true ;
333312 if (isa<AtomicRMWOp>(op))
@@ -377,17 +356,6 @@ void LayoutPropagation::initAnchorLayout() {
377356 }
378357 }
379358 });
380-
381- #if 0
382- llvm::errs() << "Initial layouts:\n";
383- for (auto &entry : layouts) {
384- llvm::errs() << entry.first << "\n";
385- for (auto &layout : entry.second.encodings) {
386- llvm::errs() << " " << layout << "\n";
387- }
388- }
389- llvm::errs() << "\n\n";
390- #endif
391359}
392360
393361void LayoutPropagation::setEncoding (ValueRange values, LayoutInfo &info,
@@ -1001,28 +969,8 @@ Operation *LayoutPropagation::rewriteOp(Operation *op) {
1001969}
1002970
1003971bool canBeRemat (Operation *op) {
1004- if (isa<LoadOp>(op)) {
1005- #ifdef HACK
1006- // Note: currently block ptr loads are always considered not expensive and
1007- // therefore rematerializable.
1008- Value base = op->getOperand (0 );
1009- auto parentLoop = op->getParentOfType <scf::ForOp>();
1010- bool isInLoop = parentLoop != nullptr ;
1011- bool isTensorPtrLoad = mlir::triton::isTensorPointerType (base.getType ());
1012-
1013- if (!isTensorPtrLoad)
1014- return !ttgi::isExpensiveLoadOrStore (op);
1015-
1016- // HACK: consider block ptr loads expensive if they are in a loop.
1017- return !isInLoop;
1018- #else
1019- return !ttgi::isExpensiveLoadOrStore (op);
1020- #endif
1021- }
1022-
1023- if (isa<StoreOp>(op))
972+ if (isa<LoadOp, StoreOp>(op))
1024973 return !ttgi::isExpensiveLoadOrStore (op);
1025-
1026974 if (isa<AtomicRMWOp, AtomicCASOp, DotOp>(op))
1027975 return false ;
1028976 if (isa<scf::WhileOp, scf::ConditionOp>(op))
0 commit comments