Skip to content

Commit 22188cc

Browse files
authored
Revert "Fix Axis analysis" (#8166)
Reverts triton-lang/triton#8144 This exposed a bug in MLIR upstream. This will get merged when we integrate the fix: llvm/llvm-project#158359
1 parent 4f7a8b8 commit 22188cc

File tree

3 files changed

+62
-4
lines changed

3 files changed

+62
-4
lines changed

lib/Analysis/AxisInfo.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1085,7 +1085,10 @@ void AxisInfoAnalysis::visitForOpInductionVar(
10851085
AxisInfo::DimVectorT knownContiguity(1, 1);
10861086
AxisInfo::DimVectorT knownDivisibility(1, 1);
10871087
AxisInfo::DimVectorT knownConstancy(1, 1);
1088-
knownDivisibility[0] = gcd(lb.getDivisibility(0), step.getDivisibility(0));
1088+
auto lbDivisibility = lb.getDivisibility();
1089+
auto stepDivisibility = step.getDivisibility();
1090+
if (!lbDivisibility.empty() && !stepDivisibility.empty())
1091+
knownDivisibility[0] = gcd(lbDivisibility[0], stepDivisibility[0]);
10891092
auto inductionVar =
10901093
AxisInfo(knownContiguity, knownDivisibility, knownConstancy);
10911094
(void)argLattices[0]->join(inductionVar);

lib/Analysis/Utility.cpp

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1165,10 +1165,63 @@ SetVector<Operation *> multiRootGetSlice(Operation *op,
11651165
return multiRootTopologicalSort(slice);
11661166
}
11671167

1168+
namespace {
1169+
// Copied from TestDeadCodeAnalysis.cpp, because some dead code analysis
1170+
// interacts with constant propagation, but SparseConstantPropagation
1171+
// doesn't seem to be sufficient.
1172+
class ConstantAnalysis : public DataFlowAnalysis {
1173+
public:
1174+
using DataFlowAnalysis::DataFlowAnalysis;
1175+
1176+
LogicalResult initialize(Operation *top) override {
1177+
WalkResult result = top->walk([&](Operation *op) {
1178+
ProgramPoint programPoint(op);
1179+
if (failed(visit(&programPoint)))
1180+
return WalkResult::interrupt();
1181+
return WalkResult::advance();
1182+
});
1183+
return success(!result.wasInterrupted());
1184+
}
1185+
1186+
LogicalResult visit(ProgramPoint *point) override {
1187+
Operation *op = point->getOperation();
1188+
Attribute value;
1189+
if (matchPattern(op, m_Constant(&value))) {
1190+
auto *constant = getOrCreate<dataflow::Lattice<dataflow::ConstantValue>>(
1191+
op->getResult(0));
1192+
propagateIfChanged(constant, constant->join(dataflow::ConstantValue(
1193+
value, op->getDialect())));
1194+
return success();
1195+
}
1196+
// Dead code analysis requires every operands has initialized ConstantValue
1197+
// state before it is visited.
1198+
// https://github.com/llvm/llvm-project/blob/2ec1aba2b69faa1de5f71832a48e25aa3b5d5314/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp#L322
1199+
// That's why we need to set all operands to unknown constants.
1200+
setAllToUnknownConstants(op->getResults());
1201+
for (Region &region : op->getRegions()) {
1202+
for (Block &block : region.getBlocks())
1203+
setAllToUnknownConstants(block.getArguments());
1204+
}
1205+
return success();
1206+
}
1207+
1208+
private:
1209+
/// Set all given values as not constants.
1210+
void setAllToUnknownConstants(ValueRange values) {
1211+
dataflow::ConstantValue unknownConstant(nullptr, nullptr);
1212+
for (Value value : values) {
1213+
auto *constant =
1214+
getOrCreate<dataflow::Lattice<dataflow::ConstantValue>>(value);
1215+
propagateIfChanged(constant, constant->join(unknownConstant));
1216+
}
1217+
}
1218+
};
1219+
} // namespace
1220+
11681221
std::unique_ptr<DataFlowSolver> createDataFlowSolver() {
11691222
auto solver = std::make_unique<DataFlowSolver>();
11701223
solver->load<dataflow::DeadCodeAnalysis>();
1171-
solver->load<dataflow::SparseConstantPropagation>();
1224+
solver->load<ConstantAnalysis>();
11721225
return solver;
11731226
}
11741227

test/Analysis/test-alignment.mlir

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -929,7 +929,8 @@ tt.func @if_into_for_init(%i1 : i1) {
929929
scf.yield %cst128 : i32
930930
}
931931
scf.for %i = %ret to %cst128 step %cst_64 : i32 {
932-
// expected-remark @below {{contiguity = [1], divisibility = [64], constancy = [1], constant_value = <none>}}
932+
// TODO: Wrong divisibility here. Fix it once llvm/llvm-project#158359 lands
933+
// expected-remark @below {{contiguity = [1], divisibility = [1], constancy = [1], constant_value = <none>}}
933934
%t = arith.addi %i, %c0 : i32
934935
}
935936
tt.return
@@ -948,7 +949,8 @@ tt.func @if_into_for_step(%i1 : i1) {
948949
scf.yield %cst128 : i32
949950
}
950951
scf.for %i = %c0 to %cst128 step %ret : i32 {
951-
// expected-remark @below {{contiguity = [1], divisibility = [64], constancy = [1], constant_value = <none>}}
952+
// TODO: Wrong divisibility here. Fix it once llvm/llvm-project#158359 lands
953+
// expected-remark @below {{contiguity = [1], divisibility = [1], constancy = [1], constant_value = <none>}}
952954
%t = arith.addi %i, %c0 : i32
953955
}
954956
tt.return

0 commit comments

Comments
 (0)