Skip to content

Commit 778bd61

Browse files
committed
Reapply "[Reland] Fix handling of unvisited operands in AxisInfoAnalysis (#8758)"
This reverts commit 546a718.
1 parent 8a1fdb4 commit 778bd61

File tree

2 files changed

+30
-19
lines changed

2 files changed

+30
-19
lines changed

lib/Analysis/AxisInfo.cpp

Lines changed: 15 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1079,11 +1079,10 @@ AxisInfoAnalysis::AxisInfoAnalysis(DataFlowSolver &solver,
10791079
LogicalResult AxisInfoAnalysis::visitOperation(
10801080
Operation *op, ArrayRef<const dataflow::Lattice<AxisInfo> *> operands,
10811081
ArrayRef<dataflow::Lattice<AxisInfo> *> results) {
1082-
// TODO: For sure not the right way to do this
1083-
// but why is scf.if not initialized otherwise?
1082+
// If any operands are not yet ready, skip this operation for now.
10841083
for (auto op : operands)
10851084
if (op->getValue().getRank() == 0)
1086-
setToEntryState((dataflow::Lattice<AxisInfo> *)op);
1085+
return success();
10871086
AxisInfo curr = visitors.apply(op, operands);
10881087
if (curr.getRank() == 0) {
10891088
setAllToEntryStates(results);
@@ -1112,9 +1111,11 @@ void AxisInfoAnalysis::visitForOpInductionVar(
11121111
ProgramPoint *programPoint = getProgramPointAfter(op);
11131112
auto *lbLattice = getLatticeElementFor(programPoint, op.getLowerBound());
11141113
auto *stepLattice = getLatticeElementFor(programPoint, op.getStep());
1115-
for (auto op_iter : {lbLattice, stepLattice})
1116-
if (op_iter->getValue().getRank() == 0)
1117-
setToEntryState((dataflow::Lattice<AxisInfo> *)op_iter);
1114+
// If lb or step is not yet ready, skip this operation for now.
1115+
if (lbLattice->getValue().getRank() == 0 ||
1116+
stepLattice->getValue().getRank() == 0) {
1117+
return;
1118+
}
11181119

11191120
AxisInfo::DimVectorT knownContiguity(1, 1);
11201121
AxisInfo::DimVectorT knownDivisibility(1, 1);
@@ -1188,24 +1189,15 @@ void AxisInfo::initDimVectorFromHint(Attribute attr, DimVectorT *vec) {
11881189
initPessimisticStateFromFunc(blockArg.getArgNumber(), fun,
11891190
&knownContiguity, &knownDivisibility,
11901191
&knownConstancy);
1191-
} else if (isa<RegionBranchOpInterface, gpu::WarpSpecializePartitionsOp>(
1192-
op)) {
1193-
// scf::ForOp, scf::IfOp, scf::WhileOp, gpu::WarpSpecializePartitionsOp
1194-
// Control flow operations are initialized with "unknown" state:
1195-
// the maximum possible divisibility, contiguity, and constancy.
1192+
} else if (isa<gpu::WarpSpecializePartitionsOp>(op)) {
1193+
// Initialize the arguments to gpu::WarpSpecializePartitionsOp with
1194+
// "unknown" state: the maximum possible divisibility, contiguity, and
1195+
// constancy.
11961196
knownDivisibility = DimVectorT(rank, kMaxDivisor);
11971197
knownConstancy = DimVectorT(rank, kMaxDivisor);
11981198
knownContiguity = DimVectorT(rank, kMaxDivisor);
11991199
}
12001200
} else if (Operation *op = value.getDefiningOp()) {
1201-
if (isa<RegionBranchOpInterface>(op)) {
1202-
// scf::ForOp, scf::IfOp, scf::WhileOp
1203-
// Control flow operations are initialized with "unknown" state:
1204-
// the maximum possible divisibility, contiguity, and constancy.
1205-
knownDivisibility = DimVectorT(rank, kMaxDivisor);
1206-
knownConstancy = DimVectorT(rank, kMaxDivisor);
1207-
knownContiguity = DimVectorT(rank, kMaxDivisor);
1208-
}
12091201
// Other operations are conservatively initialized with the lowest possible
12101202
// divisibility, contiguity, and constancy unless they have specified.
12111203
AxisInfo::initDimVectorFromHint(op->getDiscardableAttr("tt.divisibility"),
@@ -1358,6 +1350,10 @@ void ModuleAxisInfoAnalysis::initialize(FunctionOpInterface funcOp,
13581350
auto *axisInfoMap = getFuncData(funcOp);
13591351
auto updateAxisInfoMap = [&](Value value) {
13601352
auto axisInfo = analysis->getLatticeElement(value)->getValue();
1353+
// If we could not determine the AxisInfo for this value, assume the
1354+
// pessimistic state.
1355+
if (axisInfo.getRank() == 0)
1356+
axisInfo = AxisInfo::getPessimisticValueState(value);
13611357
auto &valInfo = (*axisInfoMap)[value];
13621358
valInfo = AxisInfo::join(axisInfo, valInfo);
13631359
};

test/Analysis/test-alignment.mlir

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1089,3 +1089,18 @@ tt.func public @test_inductor_for() {
10891089
}
10901090
tt.return
10911091
}
1092+
1093+
// -----
1094+
1095+
// Verify that if an operation is statically determined to be dead, we fall back
1096+
// to assigning it a pessimistic value, rather than skipping it entirely.
1097+
tt.func @dead_op_pessimistic() {
1098+
%c5 = arith.constant dense<5> : tensor<4xi32>
1099+
%c7 = arith.constant dense<7> : tensor<4xi32>
1100+
%false = arith.constant false
1101+
scf.if %false {
1102+
// expected-remark @below {{contiguity = [1], divisibility = [1], constancy = [1], constant_value = <none>}}
1103+
%add = arith.addi %c5, %c7 : tensor<4xi32>
1104+
}
1105+
tt.return
1106+
}

0 commit comments

Comments
 (0)