1414#define DBGS () (llvm::dbgs() << " [" DEBUG_TYPE " ]: " )
1515#define LDBG (X ) LLVM_DEBUG(DBGS() << X << " \n " )
1616
17+ // #define OLD 1
18+
1719namespace mlir ::triton {
1820namespace {
1921
@@ -1083,7 +1085,12 @@ LogicalResult AxisInfoAnalysis::visitOperation(
10831085 // but why is scf.if not initialized otherwise?
10841086 for (auto op : operands)
10851087 if (op->getValue ().getRank () == 0 )
1088+ #ifdef OLD
10861089 setToEntryState ((dataflow::Lattice<AxisInfo> *)op);
1090+ #else
1091+ return success ();
1092+ #endif
1093+
10871094 AxisInfo curr = visitors.apply (op, operands);
10881095 if (curr.getRank () == 0 ) {
10891096 setAllToEntryStates (results);
@@ -1112,9 +1119,17 @@ void AxisInfoAnalysis::visitForOpInductionVar(
11121119 ProgramPoint *programPoint = getProgramPointAfter (op);
11131120 auto *lbLattice = getLatticeElementFor (programPoint, op.getLowerBound ());
11141121 auto *stepLattice = getLatticeElementFor (programPoint, op.getStep ());
1122+ #ifdef OLD
11151123 for (auto op_iter : {lbLattice, stepLattice})
11161124 if (op_iter->getValue ().getRank () == 0 )
11171125 setToEntryState ((dataflow::Lattice<AxisInfo> *)op_iter);
1126+ #else
1127+ // If lb or step is not yet ready, skip this operation for now.
1128+ if (lbLattice->getValue ().getRank () == 0 ||
1129+ stepLattice->getValue ().getRank () == 0 ) {
1130+ return ;
1131+ }
1132+ #endif
11181133
11191134 AxisInfo::DimVectorT knownContiguity (1 , 1 );
11201135 AxisInfo::DimVectorT knownDivisibility (1 , 1 );
@@ -1188,16 +1203,25 @@ void AxisInfo::initDimVectorFromHint(Attribute attr, DimVectorT *vec) {
11881203 initPessimisticStateFromFunc (blockArg.getArgNumber (), fun,
11891204 &knownContiguity, &knownDivisibility,
11901205 &knownConstancy);
1191- } else if (isa<RegionBranchOpInterface, gpu::WarpSpecializePartitionsOp>(
1192- op)) {
1206+ }
1207+ #ifdef OLD
1208+ else if (isa<RegionBranchOpInterface, gpu::WarpSpecializePartitionsOp>(
1209+ op)) {
11931210 // scf::ForOp, scf::IfOp, scf::WhileOp, gpu::WarpSpecializePartitionsOp
11941211 // Control flow operations are initialized with "unknown" state:
11951212 // the maximum possible divisibility, contiguity, and constancy.
1213+ #else
1214+ else if (isa<gpu::WarpSpecializePartitionsOp>(op)) {
1215+ // Initialize the arguments to gpu::WarpSpecializePartitionsOp with
1216+ // "unknown" state: the maximum possible divisibility, contiguity, and
1217+ // constancy.
1218+ #endif
11961219 knownDivisibility = DimVectorT (rank, kMaxDivisor );
11971220 knownConstancy = DimVectorT (rank, kMaxDivisor );
11981221 knownContiguity = DimVectorT (rank, kMaxDivisor );
11991222 }
12001223 } else if (Operation *op = value.getDefiningOp ()) {
1224+ #ifdef OLD
12011225 if (isa<RegionBranchOpInterface>(op)) {
12021226 // scf::ForOp, scf::IfOp, scf::WhileOp
12031227 // Control flow operations are initialized with "unknown" state:
@@ -1206,6 +1230,7 @@ void AxisInfo::initDimVectorFromHint(Attribute attr, DimVectorT *vec) {
12061230 knownConstancy = DimVectorT (rank, kMaxDivisor );
12071231 knownContiguity = DimVectorT (rank, kMaxDivisor );
12081232 }
1233+ #endif
12091234 // Other operations are conservatively initialized with the lowest possible
12101235 // divisibility, contiguity, and constancy unless they have specified.
12111236 AxisInfo::initDimVectorFromHint (op->getDiscardableAttr (" tt.divisibility" ),
@@ -1358,6 +1383,12 @@ void ModuleAxisInfoAnalysis::initialize(FunctionOpInterface funcOp,
13581383 auto *axisInfoMap = getFuncData (funcOp);
13591384 auto updateAxisInfoMap = [&](Value value) {
13601385 auto axisInfo = analysis->getLatticeElement (value)->getValue ();
1386+ #ifdef OLD
1387+ // If we could not determine the AxisInfo for this value, assume the
1388+ // pessimistic state.
1389+ if (axisInfo.getRank () == 0 )
1390+ axisInfo = AxisInfo::getPessimisticValueState (value);
1391+ #endif
13611392 auto &valInfo = (*axisInfoMap)[value];
13621393 valInfo = AxisInfo::join (axisInfo, valInfo);
13631394 };
0 commit comments