Skip to content

Commit 3c8cdbf

Browse files
authored
Move getMakeTensorPtrOp func to Triton/IR/Utility.h (#6755)
This fixes the following build error:
```bash
In file included from /runner/_work/intel-xpu-backend-for-triton/intel-xpu-backend-for-triton/include/triton/Dialect/TritonGPU/IR/Dialect.h:12,
                 from /runner/_work/intel-xpu-backend-for-triton/intel-xpu-backend-for-triton/include/triton/Analysis/Utility.h:8,
                 from /runner/_work/intel-xpu-backend-for-triton/intel-xpu-backend-for-triton/lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp:9:
/runner/_work/intel-xpu-backend-for-triton/intel-xpu-backend-for-triton/include/triton/Dialect/TritonGPU/IR/Types.h:12:10: fatal error: triton/Dialect/TritonGPU/IR/TypeInterfaces.h.inc: No such file or directory
   12 | #include "triton/Dialect/TritonGPU/IR/TypeInterfaces.h.inc"
      |          ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
compilation terminated.
```
Signed-off-by: Anatoly Myachev <[email protected]>
1 parent 44abc7a commit 3c8cdbf

File tree

5 files changed

+78
-76
lines changed

5 files changed

+78
-76
lines changed

include/triton/Analysis/Utility.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -411,8 +411,6 @@ template <typename T> class CallGraph {
411411
// Create a basic DataFlowSolver with constant and dead code analysis included.
412412
std::unique_ptr<DataFlowSolver> createDataFlowSolver();
413413

414-
triton::MakeTensorPtrOp getMakeTensorPtrOp(Value v);
415-
416414
} // namespace mlir
417415

418416
#endif // TRITON_ANALYSIS_UTILITY_H

include/triton/Dialect/Triton/IR/Utility.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,8 @@ template <typename T> auto seq(T start, T end, T step) {
177177
Value getPredMask(RewriterBase &rewriter, Type typeLike, Value currentMask,
178178
Value pred);
179179

180+
MakeTensorPtrOp getMakeTensorPtrOp(Value v);
181+
180182
} // namespace triton
181183
} // namespace mlir
182184

lib/Analysis/Utility.cpp

Lines changed: 0 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -1029,75 +1029,4 @@ std::unique_ptr<DataFlowSolver> createDataFlowSolver() {
10291029
return solver;
10301030
}
10311031

1032-
// Helper for getMakeTensorPtrOp(): given `op`, the operation that defines
// `v`, follow the definition chain until a MakeTensorPtrOp is found.
// Handles three kinds of defining op; anything else hits llvm_unreachable.
static MakeTensorPtrOp getMakeTensorPtrOpImpl(Operation *op, Value v) {
1033-
1034-
// Base case: `v` is produced directly by a MakeTensorPtrOp.
if (auto makeTensorPtrOp = dyn_cast<MakeTensorPtrOp>(op)) {
1035-
return makeTensorPtrOp;
1036-
}
1037-
1038-
// An AdvanceOp forwards a pointer; keep searching from its pointer operand.
if (auto advanceOp = dyn_cast<AdvanceOp>(op)) {
1039-
return getMakeTensorPtrOp(advanceOp.getPtr());
1040-
}
1041-
1042-
// Region-holding op: map the result index of `v` to the operand at the same
// index of a yield found inside `op`.
if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
1043-
auto idx = cast<OpResult>(v).getResultNumber();
1044-
llvm::SmallVector<scf::YieldOp> yieldOps;
1045-
op->walk([&](Operation *op) {
1046-
if (auto yieldOp = dyn_cast<scf::YieldOp>(op))
1047-
yieldOps.push_back(yieldOp);
1048-
});
1049-
1050-
// benzh@: if there are multiple yields, the operand at this index of every
// yield is expected to come from the same argument, so only the first
// collected yield is inspected.
// NOTE(review): yieldOps[0] is read without checking that any yield was
// collected, and the walk presumably also visits yields of ops nested deeper
// inside `op` - confirm the first entry is always the intended one.
1051-
Value newValue = yieldOps[0].getOperands()[idx];
1052-
return getMakeTensorPtrOp(newValue);
1053-
}
1054-
1055-
llvm_unreachable("Unable to getMakeTensorPtr()");
1056-
}
1057-
1058-
// Returns the MakeTensorPtrOp that `v` ultimately traces back to.
// `v` may be an op result (handled by getMakeTensorPtrOpImpl) or a block
// argument of an scf::ForOp body or of a block inside a function; any other
// situation hits llvm_unreachable.
MakeTensorPtrOp getMakeTensorPtrOp(Value v) {
1059-
// Per destination block: the branch ops that jump to it, each paired with a
// flag (-1 = cf::BranchOp, 1 = true destination of a cf::CondBranchOp,
// 0 = its false destination).
using BranchOps = llvm::SetVector<std::pair<Operation *, int>>;
1060-
llvm::DenseMap<Block *, BranchOps> blockToCFOps;
1061-
auto moduleOp =
1062-
v.getParentBlock()->getParentOp()->getParentOfType<ModuleOp>();
1063-
1064-
// NOTE(review): this walks the whole enclosing module on every call,
// including each recursive call below.
moduleOp.walk([&](Operation *op) {
1065-
if (auto br = dyn_cast<cf::BranchOp>(op)) {
1066-
Block *block = br.getDest();
1067-
blockToCFOps[block].insert({op, -1});
1068-
}
1069-
if (auto condBr = dyn_cast<cf::CondBranchOp>(op)) {
1070-
Block *blockT = condBr.getTrueDest();
1071-
Block *blockF = condBr.getFalseDest();
1072-
blockToCFOps[blockT].insert({condBr, 1});
1073-
blockToCFOps[blockF].insert({condBr, 0});
1074-
}
1075-
});
1076-
1077-
// Op results are resolved through their defining operation.
if (Operation *definingOp = v.getDefiningOp())
1078-
return getMakeTensorPtrOpImpl(definingOp, v);
1079-
1080-
// If there is no defining op, v must be a BlockArgument.
1081-
BlockArgument arg = cast<BlockArgument>(v);
1082-
unsigned argNum = arg.getArgNumber();
1083-
Operation *argOwner = arg.getOwner()->getParentOp();
1084-
1085-
// Loop body argument: continue from the matching operand of the scf::ForOp.
// Presumably the "- 1" accounts for the induction variable occupying block
// argument 0 - confirm.
if (auto forOp = dyn_cast<scf::ForOp>(argOwner))
1086-
return getMakeTensorPtrOp(
1087-
forOp.getOperand(argNum + forOp.getNumControlOperands() - 1));
1088-
if (auto funcOp = dyn_cast<FunctionOpInterface>(argOwner)) {
1089-
Block *block = arg.getOwner();
1090-
Operation *op;
1091-
int tOrF;
1092-
// Only the first recorded branch into this block is followed.
// NOTE(review): element 0 is read without checking that any branch targets
// `block` - confirm this cannot be reached for a block with no predecessors.
std::tie(op, tOrF) = blockToCFOps[block][0];
1093-
if (auto br = dyn_cast<cf::BranchOp>(op))
1094-
return getMakeTensorPtrOp(br.getDestOperands()[argNum]);
1095-
if (auto condBr = dyn_cast<cf::CondBranchOp>(op))
1096-
return getMakeTensorPtrOp(tOrF ? condBr.getTrueDestOperands()[argNum]
1097-
: condBr.getFalseDestOperands()[argNum]);
1098-
// Fallback: use the operand of the function-like op at the same index.
return getMakeTensorPtrOp(argOwner->getOperand(argNum));
1099-
}
1100-
llvm_unreachable("Unable to getMakeTensorPtr()");
1101-
}
1102-
11031032
} // namespace mlir

lib/Dialect/Triton/IR/Utility.cpp

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include "triton/Dialect/Triton/IR/Utility.h"
2+
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
23
#include "triton/Dialect/Triton/IR/Dialect.h"
34

45
using namespace mlir;
@@ -17,3 +18,75 @@ Value tt::getPredMask(RewriterBase &rewriter, Type typeLike, Value currentMask,
1718
}
1819
return mask;
1920
}
21+
22+
// Helper for tt::getMakeTensorPtrOp(): given `op`, the operation that defines
// `v`, follow the definition chain until a tt::MakeTensorPtrOp is found.
// Handles three kinds of defining op; anything else hits llvm_unreachable.
static tt::MakeTensorPtrOp getMakeTensorPtrOpImpl(Operation *op, Value v) {
23+
24+
// Base case: `v` is produced directly by a MakeTensorPtrOp.
if (auto makeTensorPtrOp = dyn_cast<tt::MakeTensorPtrOp>(op)) {
25+
return makeTensorPtrOp;
26+
}
27+
28+
// An AdvanceOp forwards a pointer; keep searching from its pointer operand.
if (auto advanceOp = dyn_cast<tt::AdvanceOp>(op)) {
29+
return tt::getMakeTensorPtrOp(advanceOp.getPtr());
30+
}
31+
32+
// Region-holding op: map the result index of `v` to the operand at the same
// index of a yield found inside `op`.
if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
33+
auto idx = cast<OpResult>(v).getResultNumber();
34+
llvm::SmallVector<scf::YieldOp> yieldOps;
35+
// Collect every scf::YieldOp reached by walking `op`.
op->walk([&](Operation *op) {
36+
if (auto yieldOp = dyn_cast<scf::YieldOp>(op))
37+
yieldOps.push_back(yieldOp);
38+
});
39+
40+
// benzh@: if there are multiple yields, the operand at this index of every
// yield is expected to come from the same argument, so only the first
// collected yield is inspected.
// NOTE(review): yieldOps[0] is read without checking that any yield was
// collected, and the walk presumably also visits yields of ops nested deeper
// inside `op` - confirm the first entry is always the intended one.
41+
Value newValue = yieldOps[0].getOperands()[idx];
42+
return tt::getMakeTensorPtrOp(newValue);
43+
}
44+
45+
llvm_unreachable("Unable to getMakeTensorPtr()");
46+
}
47+
48+
// Returns the tt::MakeTensorPtrOp that `v` ultimately traces back to.
// `v` may be an op result (handled by getMakeTensorPtrOpImpl) or a block
// argument of an scf::ForOp body or of a block inside a function; any other
// situation hits llvm_unreachable.
tt::MakeTensorPtrOp tt::getMakeTensorPtrOp(Value v) {
49+
// Per destination block: the branch ops that jump to it, each paired with a
// flag (-1 = cf::BranchOp, 1 = true destination of a cf::CondBranchOp,
// 0 = its false destination).
using BranchOps = llvm::SetVector<std::pair<Operation *, int>>;
50+
llvm::DenseMap<Block *, BranchOps> blockToCFOps;
51+
auto moduleOp =
52+
v.getParentBlock()->getParentOp()->getParentOfType<ModuleOp>();
53+
54+
// NOTE(review): this walks the whole enclosing module on every call,
// including each recursive call below.
moduleOp.walk([&](Operation *op) {
55+
if (auto br = dyn_cast<cf::BranchOp>(op)) {
56+
Block *block = br.getDest();
57+
blockToCFOps[block].insert({op, -1});
58+
}
59+
if (auto condBr = dyn_cast<cf::CondBranchOp>(op)) {
60+
Block *blockT = condBr.getTrueDest();
61+
Block *blockF = condBr.getFalseDest();
62+
blockToCFOps[blockT].insert({condBr, 1});
63+
blockToCFOps[blockF].insert({condBr, 0});
64+
}
65+
});
66+
67+
// Op results are resolved through their defining operation.
if (Operation *definingOp = v.getDefiningOp())
68+
return getMakeTensorPtrOpImpl(definingOp, v);
69+
70+
// If there is no defining op, v must be a BlockArgument.
71+
BlockArgument arg = cast<BlockArgument>(v);
72+
unsigned argNum = arg.getArgNumber();
73+
Operation *argOwner = arg.getOwner()->getParentOp();
74+
75+
// Loop body argument: continue from the matching operand of the scf::ForOp.
// Presumably the "- 1" accounts for the induction variable occupying block
// argument 0 - confirm.
if (auto forOp = dyn_cast<scf::ForOp>(argOwner))
76+
return tt::getMakeTensorPtrOp(
77+
forOp.getOperand(argNum + forOp.getNumControlOperands() - 1));
78+
if (auto funcOp = dyn_cast<FunctionOpInterface>(argOwner)) {
79+
Block *block = arg.getOwner();
80+
Operation *op;
81+
int tOrF;
82+
// Only the first recorded branch into this block is followed.
// NOTE(review): element 0 is read without checking that any branch targets
// `block` - confirm this cannot be reached for a block with no predecessors.
std::tie(op, tOrF) = blockToCFOps[block][0];
83+
if (auto br = dyn_cast<cf::BranchOp>(op))
84+
return tt::getMakeTensorPtrOp(br.getDestOperands()[argNum]);
85+
// tOrF selects which successor operand list of the conditional branch feeds
// this block (non-zero = true destination).
if (auto condBr = dyn_cast<cf::CondBranchOp>(op))
86+
return tt::getMakeTensorPtrOp(
87+
tOrF ? condBr.getTrueDestOperands()[argNum]
88+
: condBr.getFalseDestOperands()[argNum]);
89+
// Fallback: use the operand of the function-like op at the same index.
return tt::getMakeTensorPtrOp(argOwner->getOperand(argNum));
90+
}
91+
llvm_unreachable("Unable to getMakeTensorPtr()");
92+
}

lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
#include "mlir/Dialect/SCF/IR/SCF.h"
77
#include "mlir/Pass/Pass.h"
88
#include "mlir/Support/LLVM.h"
9-
#include "triton/Analysis/Utility.h"
109
#include "triton/Dialect/Triton/IR/Dialect.h"
10+
#include "triton/Dialect/Triton/IR/Utility.h"
1111
#include "triton/Dialect/Triton/Transforms/Passes.h"
1212

1313
using namespace mlir;
@@ -340,7 +340,7 @@ class RewriteTensorPointerPass
340340
continue;
341341
}
342342
needRewrite = true;
343-
auto makeTensorPtrOp = getMakeTensorPtrOp(results[i]);
343+
auto makeTensorPtrOp = triton::getMakeTensorPtrOp(results[i]);
344344
assert(rewritedInfo.count(makeTensorPtrOp.getResult()));
345345
const auto &info = rewritedInfo[makeTensorPtrOp.getResult()];
346346
for (unsigned j = 0; j < info.length(); ++j) {
@@ -378,7 +378,7 @@ class RewriteTensorPointerPass
378378
oldResIdx++;
379379
newResIdx++;
380380
} else {
381-
auto makeTensorPtrOp = getMakeTensorPtrOp(results[oldResIdx]);
381+
auto makeTensorPtrOp = triton::getMakeTensorPtrOp(results[oldResIdx]);
382382
assert(rewritedInfo.count(makeTensorPtrOp.getResult()));
383383
auto info = rewritedInfo[makeTensorPtrOp.getResult()];
384384
for (unsigned j = 0; j < info.length(); ++j) {

0 commit comments

Comments
 (0)