Skip to content

Commit 9af439d

Browse files
committed
fixup case 0/1, working on case 2
1 parent d58e707 commit 9af439d

File tree

1 file changed

+41
-26
lines changed

1 file changed

+41
-26
lines changed

third_party/intel/lib/TritonIntelGPUTransforms/RewriteTensorPointer.cpp

Lines changed: 41 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -33,29 +33,26 @@ namespace {
3333
/// - it does not have Dpas layout or Dot layout (with Dpas layout as parent)
3434
/// - its pitch is not divisible by Qword bitwidth
3535
/// - it is not contiguous in memory
36-
bool shouldRemove(tt::MakeTensorPtrOp &op, const bool isUsedByLoadOrStoreOp,
37-
const bool isUsedByBlockLoadOrStoreOp) {
36+
bool shouldRemove(tt::MakeTensorPtrOp &op, const bool isUsedByStoreOp,
37+
const bool isUsedByBlockLoadOp) {
3838
LDBG("Considering removal of: " << op);
3939
if (!op->getParentOfType<ModuleOp>()->hasAttr(
4040
ttgi::TritonIntelGPUDialect::getSupportSG2DBlockAttrName())) {
4141
LDBG("Marked for removal: 2D block operation not supported");
4242
return true;
4343
}
4444

45-
if (isUsedByBlockLoadOrStoreOp) {
46-
LDBG("Used by block load/store, skipping removal");
47-
return false;
48-
}
49-
5045
auto ptrType = cast<tt::PointerType>(op.getType());
5146
LDBG("Op ptr type: " << ptrType);
5247
auto tensorType = cast<RankedTensorType>(ptrType.getPointeeType());
5348
LDBG("Op tensor type: " << tensorType);
54-
LDBG("Used by load or store op? " << isUsedByLoadOrStoreOp);
49+
LDBG("Used by store op? " << isUsedByStoreOp);
50+
LDBG("Used by block load op? " << isUsedByBlockLoadOp);
5551

56-
if (ttgi::hasDotDpasEncoding(tensorType) &&
57-
(isUsedByLoadOrStoreOp && ttgi::hasDpasEncoding(tensorType))) {
58-
LDBG("Tensor with DPAS layout is used by load/store op with DPAS layout, "
52+
LDBG("hasDotDpasEncoding: " << ttgi::hasDotDpasEncoding(tensorType));
53+
LDBG("hasDpasEncoding: " << ttgi::hasDpasEncoding(tensorType));
54+
if (ttgi::hasDotDpasEncoding(tensorType) || isUsedByBlockLoadOp || (isUsedByStoreOp && ttgi::hasDpasEncoding(tensorType))) {
55+
LDBG("Tensor has DPAS layout or is used by load/store op with DPAS layout, "
5956
"skipping removal");
6057
return false;
6158
}
@@ -692,18 +689,9 @@ class TritonIntelGPURewriteTensorPointerPass
692689
});
693690
};
694691

695-
auto markTensorPointerForRemoval =
696-
[this](Value val, bool isUsedByLoadOrStoreOp = false,
697-
bool isUsedByBlockLoadOrStoreOp = false) {
698-
if (tt::isTensorPointerType(val.getType())) {
699-
tt::MakeTensorPtrOp makeTensorPtrOp = getMakeTensorPtrOp(val);
700-
if (shouldRemove(makeTensorPtrOp, isUsedByLoadOrStoreOp,
701-
isUsedByBlockLoadOrStoreOp)) {
702-
valueToRemove.insert(val);
703-
}
704-
}
705-
};
706692

693+
// TODO: this is working, but the materialize-block-pointer pass presumably needs a matching update — original comment truncated; confirm intent
694+
DenseSet<Operation *> tensorPointersToRemove;
707695
mod.walk([&](Operation *op) {
708696
if (isa<tt::MakeTensorPtrOp>(op)) {
709697
DenseSet<Operation *> workingSet;
@@ -724,12 +712,12 @@ class TritonIntelGPURewriteTensorPointerPass
724712
if (isa<tt::LoadOp, tt::StoreOp>(crtOp)) {
725713
LDBG("is load store, should remove?");
726714
if (shouldRemove(
727-
makeTensorPtrOp, /*isUsedByLoadOrStoreOp=*/true,
728-
/*isBlockLoadOrStore=*/
729-
crtOp->hasAttr(
715+
makeTensorPtrOp, /*isUsedByStoreOp=*/isa<tt::StoreOp>(crtOp),
716+
/*isUsedByBlockLoadOp=*/
717+
isa<tt::LoadOp>(crtOp) && crtOp->hasAttr(
730718
ttgi::TritonIntelGPUDialect::getBlockIOAttrName()))) {
731719
LDBG("Removing: " << result);
732-
valueToRemove.insert(result);
720+
tensorPointersToRemove.insert(makeTensorPtrOp);
733721
}
734722
} else if (auto forOp = dyn_cast<scf::ForOp>(crtOp)) {
735723
for (auto [arg, blockArg] :
@@ -790,6 +778,33 @@ class TritonIntelGPURewriteTensorPointerPass
790778
#endif
791779
});
792780

781+
auto markTensorPointerForRemoval =
782+
[this, &tensorPointersToRemove](Value val, bool isUsedByLoadOrStoreOp = false,
783+
bool isUsedByBlockLoadOrStoreOp = false) {
784+
if (tt::isTensorPointerType(val.getType())) {
785+
tt::MakeTensorPtrOp makeTensorPtrOp = getMakeTensorPtrOp(val);
786+
if (tensorPointersToRemove.count(makeTensorPtrOp)) {
787+
valueToRemove.insert(val);
788+
}
789+
}
790+
};
791+
792+
mod.walk([&](Operation *op) {
793+
if (isa<tt::MakeTensorPtrOp>(op)) {
794+
Value result = op->getResult(0);
795+
markTensorPointerForRemoval(result, usedByLoadOrStoreOp(result));
796+
} else if (isa<tt::AdvanceOp, tt::LoadOp, tt::StoreOp>(op)) {
797+
markTensorPointerForRemoval(op->getOperand(0),
798+
isa<tt::LoadOp, tt::StoreOp>(op));
799+
} else if (auto forOp = dyn_cast<scf::ForOp>(op)) {
800+
for (auto arg : forOp.getInitArgs())
801+
markTensorPointerForRemoval(arg);
802+
} else if (auto yieldOp = dyn_cast<scf::YieldOp>(op)) {
803+
for (auto operand : yieldOp.getOperands())
804+
markTensorPointerForRemoval(operand);
805+
}
806+
});
807+
793808
LLVM_DEBUG({
794809
if (valueToRemove.empty())
795810
DBGS() << "No tensor pointer to remove";

0 commit comments

Comments
 (0)