Skip to content

Commit 84fbef0

Browse files
committed
Revert "[Reland] Fix handling of unvisited operands in AxisInfoAnalysis (#8758)"
This reverts commit 31281bc.
1 parent 6de4a5d commit 84fbef0

File tree

2 files changed

+19
-30
lines changed

2 files changed

+19
-30
lines changed

lib/Analysis/AxisInfo.cpp

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1079,10 +1079,11 @@ AxisInfoAnalysis::AxisInfoAnalysis(DataFlowSolver &solver,
10791079
LogicalResult AxisInfoAnalysis::visitOperation(
10801080
Operation *op, ArrayRef<const dataflow::Lattice<AxisInfo> *> operands,
10811081
ArrayRef<dataflow::Lattice<AxisInfo> *> results) {
1082-
// If any operands are not yet ready, skip this operation for now.
1082+
// TODO: For sure not the right way to do this
1083+
// but why is scf.if not initialized otherwise?
10831084
for (auto op : operands)
10841085
if (op->getValue().getRank() == 0)
1085-
return success();
1086+
setToEntryState((dataflow::Lattice<AxisInfo> *)op);
10861087
AxisInfo curr = visitors.apply(op, operands);
10871088
if (curr.getRank() == 0) {
10881089
setAllToEntryStates(results);
@@ -1111,11 +1112,9 @@ void AxisInfoAnalysis::visitForOpInductionVar(
11111112
ProgramPoint *programPoint = getProgramPointAfter(op);
11121113
auto *lbLattice = getLatticeElementFor(programPoint, op.getLowerBound());
11131114
auto *stepLattice = getLatticeElementFor(programPoint, op.getStep());
1114-
// If lb or step is not yet ready, skip this operation for now.
1115-
if (lbLattice->getValue().getRank() == 0 ||
1116-
stepLattice->getValue().getRank() == 0) {
1117-
return;
1118-
}
1115+
for (auto op_iter : {lbLattice, stepLattice})
1116+
if (op_iter->getValue().getRank() == 0)
1117+
setToEntryState((dataflow::Lattice<AxisInfo> *)op_iter);
11191118

11201119
AxisInfo::DimVectorT knownContiguity(1, 1);
11211120
AxisInfo::DimVectorT knownDivisibility(1, 1);
@@ -1189,15 +1188,24 @@ void AxisInfo::initDimVectorFromHint(Attribute attr, DimVectorT *vec) {
11891188
initPessimisticStateFromFunc(blockArg.getArgNumber(), fun,
11901189
&knownContiguity, &knownDivisibility,
11911190
&knownConstancy);
1192-
} else if (isa<gpu::WarpSpecializePartitionsOp>(op)) {
1193-
// Initialize the arguments to gpu::WarpSpecializePartitionsOp with
1194-
// "unknown" state: the maximum possible divisibility, contiguity, and
1195-
// constancy.
1191+
} else if (isa<RegionBranchOpInterface, gpu::WarpSpecializePartitionsOp>(
1192+
op)) {
1193+
// scf::ForOp, scf::IfOp, scf::WhileOp, gpu::WarpSpecializePartitionsOp
1194+
// Control flow operations are initialized with "unknown" state:
1195+
// the maximum possible divisibility, contiguity, and constancy.
11961196
knownDivisibility = DimVectorT(rank, kMaxDivisor);
11971197
knownConstancy = DimVectorT(rank, kMaxDivisor);
11981198
knownContiguity = DimVectorT(rank, kMaxDivisor);
11991199
}
12001200
} else if (Operation *op = value.getDefiningOp()) {
1201+
if (isa<RegionBranchOpInterface>(op)) {
1202+
// scf::ForOp, scf::IfOp, scf::WhileOp
1203+
// Control flow operations are initialized with "unknown" state:
1204+
// the maximum possible divisibility, contiguity, and constancy.
1205+
knownDivisibility = DimVectorT(rank, kMaxDivisor);
1206+
knownConstancy = DimVectorT(rank, kMaxDivisor);
1207+
knownContiguity = DimVectorT(rank, kMaxDivisor);
1208+
}
12011209
// Other operations are conservatively initialized with the lowest possible
12021210
// divisibility, contiguity, and constancy unless they have specified.
12031211
AxisInfo::initDimVectorFromHint(op->getDiscardableAttr("tt.divisibility"),
@@ -1350,10 +1358,6 @@ void ModuleAxisInfoAnalysis::initialize(FunctionOpInterface funcOp,
13501358
auto *axisInfoMap = getFuncData(funcOp);
13511359
auto updateAxisInfoMap = [&](Value value) {
13521360
auto axisInfo = analysis->getLatticeElement(value)->getValue();
1353-
// If we could not determine the AxisInfo for this value, assume the
1354-
// pessimistic state.
1355-
if (axisInfo.getRank() == 0)
1356-
axisInfo = AxisInfo::getPessimisticValueState(value);
13571361
AxisInfo curAxisInfo;
13581362
if (axisInfoMap->count(value)) {
13591363
curAxisInfo = AxisInfo::join(axisInfo, axisInfoMap->lookup(value));

test/Analysis/test-alignment.mlir

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1089,18 +1089,3 @@ tt.func public @test_inductor_for() {
10891089
}
10901090
tt.return
10911091
}
1092-
1093-
// -----
1094-
1095-
// Verify that if an operation is statically determined to be dead, we fall back
1096-
// to assigning it a pessimistic value, rather than skipping it entirely.
1097-
tt.func @dead_op_pessimistic() {
1098-
%c5 = arith.constant dense<5> : tensor<4xi32>
1099-
%c7 = arith.constant dense<7> : tensor<4xi32>
1100-
%false = arith.constant false
1101-
scf.if %false {
1102-
// expected-remark @below {{contiguity = [1], divisibility = [1], constancy = [1], constant_value = <none>}}
1103-
%add = arith.addi %c5, %c7 : tensor<4xi32>
1104-
}
1105-
tt.return
1106-
}

0 commit comments

Comments
 (0)