Skip to content

Commit 3e283c4

Browse files
neildhar authored and anmyachev committed
[Reland] Fix handling of unvisited operands in AxisInfoAnalysis (#8758)
We currently force initialisation of operands that have not yet been visited with `setToEntryState`. This means that the order in which values are visited can change the results of the analysis, which can be a source of bugs. For example, the lowering for `AsyncCopyGlobalToLocalOp` validates that the load addresses permit sufficient vectorisation; however, this relies on the analysis actually recovering the same information it had when the async copy was created. Otherwise, we crash during lowering. I have an actual repro for this, but it has been very difficult to minimise it enough to make it suitable for a lit test: https://gist.github.com/neildhar/7eea6a312afa39d1cc83dc12627c2ba3

Populating the operands in this way also means that we have to handle control flow like `ForOp` and `IfOp` explicitly in `setToEntryState`, because we may attempt to populate their results when we visit their users.

Instead, when we encounter an operation whose operands have not yet been visited, skip over the operation entirely. We can revisit it once the operands have actually been visited. This improves the quality of the analysis and leaves the handling of control flow to the dataflow framework.

This reland adds handling for the case where the dataflow analysis fails to initialise a particular value (likely because it is determined to be dead).
1 parent 13d7a47 commit 3e283c4

File tree

2 files changed

+30
-19
lines changed

2 files changed

+30
-19
lines changed

lib/Analysis/AxisInfo.cpp

Lines changed: 15 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1079,11 +1079,10 @@ AxisInfoAnalysis::AxisInfoAnalysis(DataFlowSolver &solver,
10791079
LogicalResult AxisInfoAnalysis::visitOperation(
10801080
Operation *op, ArrayRef<const dataflow::Lattice<AxisInfo> *> operands,
10811081
ArrayRef<dataflow::Lattice<AxisInfo> *> results) {
1082-
// TODO: For sure not the right way to do this
1083-
// but why is scf.if not initialized otherwise?
1082+
// If any operands are not yet ready, skip this operation for now.
10841083
for (auto op : operands)
10851084
if (op->getValue().getRank() == 0)
1086-
setToEntryState((dataflow::Lattice<AxisInfo> *)op);
1085+
return success();
10871086
AxisInfo curr = visitors.apply(op, operands);
10881087
if (curr.getRank() == 0) {
10891088
setAllToEntryStates(results);
@@ -1112,9 +1111,11 @@ void AxisInfoAnalysis::visitForOpInductionVar(
11121111
ProgramPoint *programPoint = getProgramPointAfter(op);
11131112
auto *lbLattice = getLatticeElementFor(programPoint, op.getLowerBound());
11141113
auto *stepLattice = getLatticeElementFor(programPoint, op.getStep());
1115-
for (auto op_iter : {lbLattice, stepLattice})
1116-
if (op_iter->getValue().getRank() == 0)
1117-
setToEntryState((dataflow::Lattice<AxisInfo> *)op_iter);
1114+
// If lb or step is not yet ready, skip this operation for now.
1115+
if (lbLattice->getValue().getRank() == 0 ||
1116+
stepLattice->getValue().getRank() == 0) {
1117+
return;
1118+
}
11181119

11191120
AxisInfo::DimVectorT knownContiguity(1, 1);
11201121
AxisInfo::DimVectorT knownDivisibility(1, 1);
@@ -1188,24 +1189,15 @@ void AxisInfo::initDimVectorFromHint(Attribute attr, DimVectorT *vec) {
11881189
initPessimisticStateFromFunc(blockArg.getArgNumber(), fun,
11891190
&knownContiguity, &knownDivisibility,
11901191
&knownConstancy);
1191-
} else if (isa<RegionBranchOpInterface, gpu::WarpSpecializePartitionsOp>(
1192-
op)) {
1193-
// scf::ForOp, scf::IfOp, scf::WhileOp, gpu::WarpSpecializePartitionsOp
1194-
// Control flow operations are initialized with "unknown" state:
1195-
// the maximum possible divisibility, contiguity, and constancy.
1192+
} else if (isa<gpu::WarpSpecializePartitionsOp>(op)) {
1193+
// Initialize the arguments to gpu::WarpSpecializePartitionsOp with
1194+
// "unknown" state: the maximum possible divisibility, contiguity, and
1195+
// constancy.
11961196
knownDivisibility = DimVectorT(rank, kMaxDivisor);
11971197
knownConstancy = DimVectorT(rank, kMaxDivisor);
11981198
knownContiguity = DimVectorT(rank, kMaxDivisor);
11991199
}
12001200
} else if (Operation *op = value.getDefiningOp()) {
1201-
if (isa<RegionBranchOpInterface>(op)) {
1202-
// scf::ForOp, scf::IfOp, scf::WhileOp
1203-
// Control flow operations are initialized with "unknown" state:
1204-
// the maximum possible divisibility, contiguity, and constancy.
1205-
knownDivisibility = DimVectorT(rank, kMaxDivisor);
1206-
knownConstancy = DimVectorT(rank, kMaxDivisor);
1207-
knownContiguity = DimVectorT(rank, kMaxDivisor);
1208-
}
12091201
// Other operations are conservatively initialized with the lowest possible
12101202
// divisibility, contiguity, and constancy unless they have specified.
12111203
AxisInfo::initDimVectorFromHint(op->getDiscardableAttr("tt.divisibility"),
@@ -1358,6 +1350,10 @@ void ModuleAxisInfoAnalysis::initialize(FunctionOpInterface funcOp,
13581350
auto *axisInfoMap = getFuncData(funcOp);
13591351
auto updateAxisInfoMap = [&](Value value) {
13601352
auto axisInfo = analysis->getLatticeElement(value)->getValue();
1353+
// If we could not determine the AxisInfo for this value, assume the
1354+
// pessimistic state.
1355+
if (axisInfo.getRank() == 0)
1356+
axisInfo = AxisInfo::getPessimisticValueState(value);
13611357
AxisInfo curAxisInfo;
13621358
if (axisInfoMap->count(value)) {
13631359
curAxisInfo = AxisInfo::join(axisInfo, axisInfoMap->lookup(value));

test/Analysis/test-alignment.mlir

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1089,3 +1089,18 @@ tt.func public @test_inductor_for() {
10891089
}
10901090
tt.return
10911091
}
1092+
1093+
// -----
1094+
1095+
// Verify that if an operation is statically determined to be dead, we fall back
1096+
// to assigning it a pessimistic value, rather than skipping it entirely.
1097+
tt.func @dead_op_pessimistic() {
1098+
%c5 = arith.constant dense<5> : tensor<4xi32>
1099+
%c7 = arith.constant dense<7> : tensor<4xi32>
1100+
%false = arith.constant false
1101+
scf.if %false {
1102+
// expected-remark @below {{contiguity = [1], divisibility = [1], constancy = [1], constant_value = <none>}}
1103+
%add = arith.addi %c5, %c7 : tensor<4xi32>
1104+
}
1105+
tt.return
1106+
}

0 commit comments

Comments
 (0)