Skip to content

Commit 365d43e

Browse files
committed
format and cleanup
1 parent e0e595b commit 365d43e

File tree

1 file changed

+16
-69
lines changed

1 file changed

+16
-69
lines changed

third_party/intel/lib/TritonIntelGPUTransforms/RewriteTensorPointer.cpp

Lines changed: 16 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -674,21 +674,6 @@ class TritonIntelGPURewriteTensorPointerPass
674674
void runOnOperation() override {
675675
ModuleOp mod = getOperation();
676676

677-
// TODO: do we need this attribute?
678-
auto usedByLoadOrStoreOp = [](Value val,
679-
const bool check_block_io_attribute = false) {
680-
return llvm::any_of(
681-
val.getUsers(), [check_block_io_attribute](Operation *user) {
682-
const bool is_load_or_store = isa<tt::LoadOp, tt::StoreOp>(user);
683-
if (check_block_io_attribute) {
684-
return user->hasAttr(
685-
ttgi::TritonIntelGPUDialect::getBlockIOAttrName());
686-
} else {
687-
return is_load_or_store;
688-
}
689-
});
690-
};
691-
692677
DenseSet<Operation *> tensorPointersToRemove;
693678
mod.walk([&](Operation *op) {
694679
if (isa<tt::MakeTensorPtrOp>(op)) {
@@ -698,24 +683,22 @@ class TritonIntelGPURewriteTensorPointerPass
698683
LDBG("Considering: " << *op);
699684
Value result = op->getResult(0);
700685
for (auto user : result.getUsers()) {
701-
workingSet.insert(user); // TODO: safe? need to check ptr?
686+
workingSet.insert(user);
702687
}
703688
while (!workingSet.empty()) {
704-
for (auto v : workingSet) {
705-
LDBG("Working set val: " << *v);
706-
}
707689
auto crtOpItr = workingSet.begin();
708690
auto crtOp = *crtOpItr;
709691
LDBG("Processing op: " << *crtOp);
710692
if (isa<tt::LoadOp, tt::StoreOp>(crtOp)) {
711-
LDBG("is load store, should remove?");
693+
LDBG("is load store, checking to see if we should remove make "
694+
"tensor ptr op");
712695
if (shouldRemove(makeTensorPtrOp,
713696
/*isUsedByStoreOp=*/isa<tt::StoreOp>(crtOp),
714697
/*isBlockLoad=*/
715698
isa<tt::LoadOp>(crtOp) &&
716699
crtOp->hasAttr(ttgi::TritonIntelGPUDialect::
717700
getBlockIOAttrName()))) {
718-
LDBG("Removing: " << result);
701+
LDBG("Marking op for removal: " << result);
719702
tensorPointersToRemove.insert(makeTensorPtrOp);
720703
}
721704
} else if (auto forOp = dyn_cast<scf::ForOp>(crtOp)) {
@@ -730,14 +713,8 @@ class TritonIntelGPURewriteTensorPointerPass
730713
}
731714
}
732715
}
733-
#if 0
734-
} else if (auto yieldOp = dyn_cast<scf::YieldOp>(op)) {
735-
for (auto operand : yieldOp.getOperands()) {
736-
workingSet.insert(operand->getResult(0));
737-
}
738-
#endif
739716
} else if (crtOp->getNumResults() > 0) {
740-
// TODO: handle more than one result?
717+
// TODO: should we handle more than one result?
741718
auto crtOpResult = crtOp->getResult(0);
742719
LDBG("Not a load store and not a loop, adding users to working "
743720
"set.");
@@ -747,55 +724,25 @@ class TritonIntelGPURewriteTensorPointerPass
747724
}
748725
workingSet.erase(crtOpItr);
749726
}
750-
#if 1
751-
}
752-
#else
753-
} else if (isa<tt::AdvanceOp, tt::LoadOp, tt::StoreOp>(op)) {
754-
const bool isLoadStoreOp = isa<tt::LoadOp, tt::StoreOp>(op);
755-
markTensorPointerForRemoval(
756-
op->getOperand(0), isLoadStoreOp,
757-
isLoadStoreOp &&
758-
op->hasAttr(ttgi::TritonIntelGPUDialect::getBlockIOAttrName()));
759-
} else if (auto forOp = dyn_cast<scf::ForOp>(op)) {
760-
for (auto [arg, blockArg] :
761-
llvm::zip(forOp.getInitArgs(),
762-
forOp.getBody()->getArguments().drop_front(
763-
forOp.getNumInductionVars()))) {
764-
LDBG("arg: " << arg);
765-
if (isa<tt::MakeTensorPtrOp>(arg.getDefiningOp())) {
766-
constexpr bool check_block_io_attribute = true;
767-
markTensorPointerForRemoval(
768-
arg.getDefiningOp()->getResult(0),
769-
usedByLoadOrStoreOp(blockArg),
770-
usedByLoadOrStoreOp(blockArg, check_block_io_attribute));
771-
}
772-
}
773-
} else if (auto yieldOp = dyn_cast<scf::YieldOp>(op)) {
774-
for (auto operand : yieldOp.getOperands())
775-
markTensorPointerForRemoval(operand);
776727
}
777-
#endif
778728
});
779729

780-
auto markTensorPointerForRemoval =
781-
[this,
782-
&tensorPointersToRemove](Value val, bool isUsedByLoadOrStoreOp = false,
783-
bool isUsedByBlockLoadOrStoreOp = false) {
784-
if (tt::isTensorPointerType(val.getType())) {
785-
tt::MakeTensorPtrOp makeTensorPtrOp = getMakeTensorPtrOp(val);
786-
if (tensorPointersToRemove.count(makeTensorPtrOp)) {
787-
valueToRemove.insert(val);
788-
}
789-
}
790-
};
730+
auto markTensorPointerForRemoval = [this,
731+
&tensorPointersToRemove](Value val) {
732+
if (tt::isTensorPointerType(val.getType())) {
733+
tt::MakeTensorPtrOp makeTensorPtrOp = getMakeTensorPtrOp(val);
734+
if (tensorPointersToRemove.count(makeTensorPtrOp)) {
735+
valueToRemove.insert(val);
736+
}
737+
}
738+
};
791739

792740
mod.walk([&](Operation *op) {
793741
if (isa<tt::MakeTensorPtrOp>(op)) {
794742
Value result = op->getResult(0);
795-
markTensorPointerForRemoval(result, usedByLoadOrStoreOp(result));
743+
markTensorPointerForRemoval(result);
796744
} else if (isa<tt::AdvanceOp, tt::LoadOp, tt::StoreOp>(op)) {
797-
markTensorPointerForRemoval(op->getOperand(0),
798-
isa<tt::LoadOp, tt::StoreOp>(op));
745+
markTensorPointerForRemoval(op->getOperand(0));
799746
} else if (auto forOp = dyn_cast<scf::ForOp>(op)) {
800747
for (auto arg : forOp.getInitArgs())
801748
markTensorPointerForRemoval(arg);

0 commit comments

Comments
 (0)