
Commit d58e707

broken: need to remove every tensor ptr type in the chain
1 parent bf881ed commit d58e707

File tree

1 file changed: +88, -13 lines


third_party/intel/lib/TritonIntelGPUTransforms/RewriteTensorPointer.cpp

Lines changed: 88 additions & 13 deletions
@@ -33,26 +33,36 @@ namespace {
 /// - it does not have Dpas layout or Dot layout (with Dpas layout as parent)
 /// - its pitch is not divisible by Qword bitwidth
 /// - it is not contiguous in memory
-bool shouldRemove(tt::MakeTensorPtrOp &op, const bool isUsedByLoadOrStoreOp) {
+bool shouldRemove(tt::MakeTensorPtrOp &op, const bool isUsedByLoadOrStoreOp,
+                  const bool isUsedByBlockLoadOrStoreOp) {
   LDBG("Considering removal of: " << op);
   if (!op->getParentOfType<ModuleOp>()->hasAttr(
           ttgi::TritonIntelGPUDialect::getSupportSG2DBlockAttrName())) {
     LDBG("Marked for removal: 2D block operation not supported");
     return true;
   }
 
+  if (isUsedByBlockLoadOrStoreOp) {
+    LDBG("Used by block load/store, skipping removal");
+    return false;
+  }
+
   auto ptrType = cast<tt::PointerType>(op.getType());
   LDBG("Op ptr type: " << ptrType);
   auto tensorType = cast<RankedTensorType>(ptrType.getPointeeType());
   LDBG("Op tensor type: " << tensorType);
+  LDBG("Used by load or store op? " << isUsedByLoadOrStoreOp);
 
-  if (!ttgi::hasDotDpasEncoding(tensorType) &&
-      !(isUsedByLoadOrStoreOp && ttgi::hasDpasEncoding(tensorType))) {
-    LDBG("Marked for removal: tensor doesn't have DPAS layout and is not used "
-         "by load or store op with DPAS layout");
-    return true;
+  if (ttgi::hasDotDpasEncoding(tensorType) &&
+      (isUsedByLoadOrStoreOp && ttgi::hasDpasEncoding(tensorType))) {
+    LDBG("Tensor with DPAS layout is used by load/store op with DPAS layout, "
+         "skipping removal");
+    return false;
   }
-  return false;
+
+  LDBG("Marked for removal: tensor doesn't have DPAS layout and is not used "
+       "by load or store op with DPAS layout");
+  return true;
 }
 
 /// The `RewritedInfo` struct is used to store information about a rewritten
@@ -683,37 +693,101 @@ class TritonIntelGPURewriteTensorPointerPass
     };
 
     auto markTensorPointerForRemoval =
-        [this](Value val, bool isUsedByLoadOrStoreOp = false) {
+        [this](Value val, bool isUsedByLoadOrStoreOp = false,
+               bool isUsedByBlockLoadOrStoreOp = false) {
           if (tt::isTensorPointerType(val.getType())) {
             tt::MakeTensorPtrOp makeTensorPtrOp = getMakeTensorPtrOp(val);
-            if (shouldRemove(makeTensorPtrOp, isUsedByLoadOrStoreOp))
+            if (shouldRemove(makeTensorPtrOp, isUsedByLoadOrStoreOp,
+                             isUsedByBlockLoadOrStoreOp)) {
              valueToRemove.insert(val);
+            }
          }
        };
 
    mod.walk([&](Operation *op) {
      if (isa<tt::MakeTensorPtrOp>(op)) {
+        DenseSet<Operation *> workingSet;
+
+        auto makeTensorPtrOp = dyn_cast<tt::MakeTensorPtrOp>(op);
+        LDBG("Considering: " << *op);
        Value result = op->getResult(0);
-        markTensorPointerForRemoval(result, usedByLoadOrStoreOp(result));
+        for (auto user : result.getUsers()) {
+          workingSet.insert(user); // TODO: safe? need to check ptr?
+        }
+        while (!workingSet.empty()) {
+          for (auto v : workingSet) {
+            LDBG("Working set val: " << *v);
+          }
+          auto crtOpItr = workingSet.begin();
+          auto crtOp = *crtOpItr;
+          LDBG("Processing op: " << *crtOp);
+          if (isa<tt::LoadOp, tt::StoreOp>(crtOp)) {
+            LDBG("is load store, should remove?");
+            if (shouldRemove(
+                    makeTensorPtrOp, /*isUsedByLoadOrStoreOp=*/true,
+                    /*isBlockLoadOrStore=*/
+                    crtOp->hasAttr(
+                        ttgi::TritonIntelGPUDialect::getBlockIOAttrName()))) {
+              LDBG("Removing: " << result);
+              valueToRemove.insert(result);
+            }
+          } else if (auto forOp = dyn_cast<scf::ForOp>(crtOp)) {
+            for (auto [arg, blockArg] :
+                 llvm::zip(forOp.getInitArgs(),
+                           forOp.getBody()->getArguments().drop_front(
+                               forOp.getNumInductionVars()))) {
+              if (arg == makeTensorPtrOp) {
+                // add users of block arg
+                for (auto user : blockArg.getUsers()) {
+                  workingSet.insert(user);
+                }
+              }
+            }
+#if 0
+          } else if (auto yieldOp = dyn_cast<scf::YieldOp>(op)) {
+            for (auto operand : yieldOp.getOperands()) {
+              workingSet.insert(operand->getResult(0));
+            }
+#endif
+          } else if (crtOp->getNumResults() > 0) {
+            // TODO: handle more than one result?
+            auto crtOpResult = crtOp->getResult(0);
+            LDBG("Not a load store and not a loop, adding users to working "
+                 "set.");
+            for (auto user : crtOpResult.getUsers()) {
+              workingSet.insert(user);
+            }
+          }
+          workingSet.erase(crtOpItr);
+        }
+#if 1
+      }
+#else
      } else if (isa<tt::AdvanceOp, tt::LoadOp, tt::StoreOp>(op)) {
-        markTensorPointerForRemoval(op->getOperand(0),
-                                    isa<tt::LoadOp, tt::StoreOp>(op));
+        const bool isLoadStoreOp = isa<tt::LoadOp, tt::StoreOp>(op);
+        markTensorPointerForRemoval(
+            op->getOperand(0), isLoadStoreOp,
+            isLoadStoreOp &&
+                op->hasAttr(ttgi::TritonIntelGPUDialect::getBlockIOAttrName()));
      } else if (auto forOp = dyn_cast<scf::ForOp>(op)) {
        for (auto [arg, blockArg] :
             llvm::zip(forOp.getInitArgs(),
                       forOp.getBody()->getArguments().drop_front(
                           forOp.getNumInductionVars()))) {
+          LDBG("arg: " << arg);
          if (isa<tt::MakeTensorPtrOp>(arg.getDefiningOp())) {
            constexpr bool check_block_io_attribute = true;
            markTensorPointerForRemoval(
                arg.getDefiningOp()->getResult(0),
+                usedByLoadOrStoreOp(blockArg),
                usedByLoadOrStoreOp(blockArg, check_block_io_attribute));
          }
        }
      } else if (auto yieldOp = dyn_cast<scf::YieldOp>(op)) {
        for (auto operand : yieldOp.getOperands())
          markTensorPointerForRemoval(operand);
      }
+#endif
    });
 
    LLVM_DEBUG({
@@ -722,7 +796,7 @@ class TritonIntelGPURewriteTensorPointerPass
      else {
        DBGS() << "Values to remove: ";
        for (auto val : valueToRemove)
-          DBGS() << val;
+          DBGS() << val << "\n";
      }
    });

@@ -746,6 +820,7 @@ class TritonIntelGPURewriteTensorPointerPass
    valueToRemove.clear();
    while (!eraser.empty()) {
      auto op = eraser.top();
+      LDBG("DELETING " << *op);
      eraser.pop();
      op->erase();
    }
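Note: the heart of this change is a worklist-style walk over the users of each tt::MakeTensorPtrOp, so that every tensor pointer value reachable through the def-use chain (including through scf.for block arguments) can be considered for removal. As a rough illustration only, not the commit's code, a minimal sketch of that kind of def-use worklist traversal over generic MLIR operations could look like the following; the helper name collectTransitiveUsers is invented for this sketch, and it deliberately ignores the loop-carried block-argument case that the pass handles above.

// Illustrative sketch (not part of the commit): collect every op transitively
// reachable through uses of `root`, following results of intermediate ops.
#include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/SmallVector.h"

static void collectTransitiveUsers(mlir::Value root,
                                   llvm::DenseSet<mlir::Operation *> &seen) {
  llvm::SmallVector<mlir::Operation *> worklist(root.getUsers().begin(),
                                                root.getUsers().end());
  while (!worklist.empty()) {
    mlir::Operation *op = worklist.pop_back_val();
    if (!seen.insert(op).second)
      continue; // already visited
    // Follow every result to its users; the real pass additionally steps
    // from scf.for init args into the corresponding block arguments.
    for (mlir::Value result : op->getResults())
      for (mlir::Operation *user : result.getUsers())
        worklist.push_back(user);
  }
}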
