Skip to content

Commit 3636bef

Browse files
committed
Make isExpensiveLoadOrStore consider block pointer loads and stores
Signed-off-by: Tiotto, Ettore <[email protected]>
1 parent 0887245 commit 3636bef

File tree

2 files changed

+62
-4
lines changed

2 files changed

+62
-4
lines changed

third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp

Lines changed: 54 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -305,8 +305,29 @@ bool hasConvertToMMATransisitiveUse(Operation *op, Attribute encoding) {
305305
// Return true if the op is an op with a layout we don't want to change. We will
306306
// propagate the layout starting from anchor ops.
307307
bool isLayoutAnchor(Operation *op) {
308-
if (isa<LoadOp, StoreOp>(op))
308+
if (isa<LoadOp>(op)) {
309+
#ifdef HACK
310+
// Note: currently block ptr loads are always considered not expensive and
311+
// therefore they are never layout anchors.
312+
Value base = op->getOperand(0);
313+
auto parentLoop = op->getParentOfType<scf::ForOp>();
314+
bool isInLoop = parentLoop != nullptr;
315+
bool isTensorPtrLoad = mlir::triton::isTensorPointerType(base.getType());
316+
317+
if (!isTensorPtrLoad)
318+
ttgi::isExpensiveLoadOrStore(op);
319+
320+
// HACK: consider block ptr loads expensive if they are in a loop.
321+
return isInLoop;
322+
#else
309323
return ttgi::isExpensiveLoadOrStore(op);
324+
#endif
325+
}
326+
327+
if (isa<StoreOp>(op)) {
328+
return ttgi::isExpensiveLoadOrStore(op);
329+
}
330+
310331
if (isa<DotOp, AtomicCASOp>(op))
311332
return true;
312333
if (isa<AtomicRMWOp>(op))
@@ -356,6 +377,17 @@ void LayoutPropagation::initAnchorLayout() {
356377
}
357378
}
358379
});
380+
381+
#if 0
382+
llvm::errs() << "Initial layouts:\n";
383+
for (auto &entry : layouts) {
384+
llvm::errs() << entry.first << "\n";
385+
for (auto &layout : entry.second.encodings) {
386+
llvm::errs() << " " << layout << "\n";
387+
}
388+
}
389+
llvm::errs() << "\n\n";
390+
#endif
359391
}
360392

361393
void LayoutPropagation::setEncoding(ValueRange values, LayoutInfo &info,
@@ -969,8 +1001,28 @@ Operation *LayoutPropagation::rewriteOp(Operation *op) {
9691001
}
9701002

9711003
bool canBeRemat(Operation *op) {
972-
if (isa<LoadOp, StoreOp>(op))
1004+
if (isa<LoadOp>(op)) {
1005+
#ifdef HACK
1006+
// Note: currently block ptr loads are always considered not expensive and
1007+
// therefore rematerializable.
1008+
Value base = op->getOperand(0);
1009+
auto parentLoop = op->getParentOfType<scf::ForOp>();
1010+
bool isInLoop = parentLoop != nullptr;
1011+
bool isTensorPtrLoad = mlir::triton::isTensorPointerType(base.getType());
1012+
1013+
if (!isTensorPtrLoad)
1014+
return !ttgi::isExpensiveLoadOrStore(op);
1015+
1016+
// HACK: consider block ptr loads expensive if they are in a loop.
1017+
return !isInLoop;
1018+
#else
1019+
return !ttgi::isExpensiveLoadOrStore(op);
1020+
#endif
1021+
}
1022+
1023+
if (isa<StoreOp>(op))
9731024
return !ttgi::isExpensiveLoadOrStore(op);
1025+
9741026
if (isa<AtomicRMWOp, AtomicCASOp, DotOp>(op))
9751027
return false;
9761028
if (isa<scf::WhileOp, scf::ConditionOp>(op))

third_party/intel/lib/TritonIntelGPUTransforms/Utility.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,9 +90,15 @@ bool isExpensiveLoadOrStore(Operation *op) {
9090
if (isSingleValue(base))
9191
return false;
9292

93-
// Case 2: Tensor of pointers has more threads than elements
94-
// we can presume a high hit-rate that makes it cheap to load
93+
// Case 2: Tensor of pointers has more threads than elements
94+
// we can presume a high hit-rate that makes it cheap to load
95+
96+
#define NEW 1
97+
#ifdef NEW
98+
if (auto ptrType = getRankedTensorType(base.getType())) {
99+
#else
95100
if (auto ptrType = dyn_cast<RankedTensorType>(base.getType())) {
101+
#endif
96102
auto mod = op->getParentOfType<ModuleOp>();
97103
int numWarps = ttg::TritonGPUDialect::getNumWarps(mod);
98104
int threadsPerWarp = ttg::TritonGPUDialect::getThreadsPerWarp(mod);

0 commit comments

Comments
 (0)