Skip to content

Commit d402827

Browse files
committed
TEST
Signed-off-by: Ettore Tiotto <[email protected]>
1 parent 0bf934b commit d402827

File tree

1 file changed

+33
-2
lines changed

1 file changed

+33
-2
lines changed

lib/Analysis/AxisInfo.cpp

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
1515
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
1616

17+
// #define OLD 1
18+
1719
namespace mlir::triton {
1820
namespace {
1921

@@ -1083,7 +1085,12 @@ LogicalResult AxisInfoAnalysis::visitOperation(
10831085
// but why is scf.if not initialized otherwise?
10841086
for (auto op : operands)
10851087
if (op->getValue().getRank() == 0)
1088+
#ifdef OLD
10861089
setToEntryState((dataflow::Lattice<AxisInfo> *)op);
1090+
#else
1091+
return success();
1092+
#endif
1093+
10871094
AxisInfo curr = visitors.apply(op, operands);
10881095
if (curr.getRank() == 0) {
10891096
setAllToEntryStates(results);
@@ -1112,9 +1119,17 @@ void AxisInfoAnalysis::visitForOpInductionVar(
11121119
ProgramPoint *programPoint = getProgramPointAfter(op);
11131120
auto *lbLattice = getLatticeElementFor(programPoint, op.getLowerBound());
11141121
auto *stepLattice = getLatticeElementFor(programPoint, op.getStep());
1122+
#ifdef OLD
11151123
for (auto op_iter : {lbLattice, stepLattice})
11161124
if (op_iter->getValue().getRank() == 0)
11171125
setToEntryState((dataflow::Lattice<AxisInfo> *)op_iter);
1126+
#else
1127+
// If lb or step is not yet ready, skip this operation for now.
1128+
if (lbLattice->getValue().getRank() == 0 ||
1129+
stepLattice->getValue().getRank() == 0) {
1130+
return;
1131+
}
1132+
#endif
11181133

11191134
AxisInfo::DimVectorT knownContiguity(1, 1);
11201135
AxisInfo::DimVectorT knownDivisibility(1, 1);
@@ -1188,16 +1203,25 @@ void AxisInfo::initDimVectorFromHint(Attribute attr, DimVectorT *vec) {
11881203
initPessimisticStateFromFunc(blockArg.getArgNumber(), fun,
11891204
&knownContiguity, &knownDivisibility,
11901205
&knownConstancy);
1191-
} else if (isa<RegionBranchOpInterface, gpu::WarpSpecializePartitionsOp>(
1192-
op)) {
1206+
}
1207+
#ifdef OLD
1208+
else if (isa<RegionBranchOpInterface, gpu::WarpSpecializePartitionsOp>(
1209+
op)) {
11931210
// scf::ForOp, scf::IfOp, scf::WhileOp, gpu::WarpSpecializePartitionsOp
11941211
// Control flow operations are initialized with "unknown" state:
11951212
// the maximum possible divisibility, contiguity, and constancy.
1213+
#else
1214+
else if (isa<gpu::WarpSpecializePartitionsOp>(op)) {
1215+
// Initialize the arguments to gpu::WarpSpecializePartitionsOp with
1216+
// "unknown" state: the maximum possible divisibility, contiguity, and
1217+
// constancy.
1218+
#endif
11961219
knownDivisibility = DimVectorT(rank, kMaxDivisor);
11971220
knownConstancy = DimVectorT(rank, kMaxDivisor);
11981221
knownContiguity = DimVectorT(rank, kMaxDivisor);
11991222
}
12001223
} else if (Operation *op = value.getDefiningOp()) {
1224+
#ifdef OLD
12011225
if (isa<RegionBranchOpInterface>(op)) {
12021226
// scf::ForOp, scf::IfOp, scf::WhileOp
12031227
// Control flow operations are initialized with "unknown" state:
@@ -1206,6 +1230,7 @@ void AxisInfo::initDimVectorFromHint(Attribute attr, DimVectorT *vec) {
12061230
knownConstancy = DimVectorT(rank, kMaxDivisor);
12071231
knownContiguity = DimVectorT(rank, kMaxDivisor);
12081232
}
1233+
#endif
12091234
// Other operations are conservatively initialized with the lowest possible
12101235
// divisibility, contiguity, and constancy unless they have specified.
12111236
AxisInfo::initDimVectorFromHint(op->getDiscardableAttr("tt.divisibility"),
@@ -1358,6 +1383,12 @@ void ModuleAxisInfoAnalysis::initialize(FunctionOpInterface funcOp,
13581383
auto *axisInfoMap = getFuncData(funcOp);
13591384
auto updateAxisInfoMap = [&](Value value) {
13601385
auto axisInfo = analysis->getLatticeElement(value)->getValue();
1386+
#ifdef OLD
1387+
// If we could not determine the AxisInfo for this value, assume the
1388+
// pessimistic state.
1389+
if (axisInfo.getRank() == 0)
1390+
axisInfo = AxisInfo::getPessimisticValueState(value);
1391+
#endif
13611392
auto &valInfo = (*axisInfoMap)[value];
13621393
valInfo = AxisInfo::join(axisInfo, valInfo);
13631394
};

0 commit comments

Comments
 (0)