@@ -1079,10 +1079,11 @@ AxisInfoAnalysis::AxisInfoAnalysis(DataFlowSolver &solver,
10791079LogicalResult AxisInfoAnalysis::visitOperation (
10801080 Operation *op, ArrayRef<const dataflow::Lattice<AxisInfo> *> operands,
10811081 ArrayRef<dataflow::Lattice<AxisInfo> *> results) {
1082- // If any operands are not yet ready, skip this operation for now.
1082+ // TODO: For sure not the right way to do this
1083+ // but why is scf.if not initialized otherwise?
10831084 for (auto op : operands)
10841085 if (op->getValue ().getRank () == 0 )
1085- return success ( );
1086+ setToEntryState ((dataflow::Lattice<AxisInfo> *)op );
10861087 AxisInfo curr = visitors.apply (op, operands);
10871088 if (curr.getRank () == 0 ) {
10881089 setAllToEntryStates (results);
@@ -1111,11 +1112,9 @@ void AxisInfoAnalysis::visitForOpInductionVar(
11111112 ProgramPoint *programPoint = getProgramPointAfter (op);
11121113 auto *lbLattice = getLatticeElementFor (programPoint, op.getLowerBound ());
11131114 auto *stepLattice = getLatticeElementFor (programPoint, op.getStep ());
1114- // If lb or step is not yet ready, skip this operation for now.
1115- if (lbLattice->getValue ().getRank () == 0 ||
1116- stepLattice->getValue ().getRank () == 0 ) {
1117- return ;
1118- }
1115+ for (auto op_iter : {lbLattice, stepLattice})
1116+ if (op_iter->getValue ().getRank () == 0 )
1117+ setToEntryState ((dataflow::Lattice<AxisInfo> *)op_iter);
11191118
11201119 AxisInfo::DimVectorT knownContiguity (1 , 1 );
11211120 AxisInfo::DimVectorT knownDivisibility (1 , 1 );
@@ -1189,15 +1188,24 @@ void AxisInfo::initDimVectorFromHint(Attribute attr, DimVectorT *vec) {
11891188 initPessimisticStateFromFunc (blockArg.getArgNumber (), fun,
11901189 &knownContiguity, &knownDivisibility,
11911190 &knownConstancy);
1192- } else if (isa<gpu::WarpSpecializePartitionsOp>(op)) {
1193- // Initialize the arguments to gpu::WarpSpecializePartitionsOp with
1194- // "unknown" state: the maximum possible divisibility, contiguity, and
1195- // constancy.
1191+ } else if (isa<RegionBranchOpInterface, gpu::WarpSpecializePartitionsOp>(
1192+ op)) {
1193+ // scf::ForOp, scf::IfOp, scf::WhileOp, gpu::WarpSpecializePartitionsOp
1194+ // Control flow operations are initialized with "unknown" state:
1195+ // the maximum possible divisibility, contiguity, and constancy.
11961196 knownDivisibility = DimVectorT (rank, kMaxDivisor );
11971197 knownConstancy = DimVectorT (rank, kMaxDivisor );
11981198 knownContiguity = DimVectorT (rank, kMaxDivisor );
11991199 }
12001200 } else if (Operation *op = value.getDefiningOp ()) {
1201+ if (isa<RegionBranchOpInterface>(op)) {
1202+ // scf::ForOp, scf::IfOp, scf::WhileOp
1203+ // Control flow operations are initialized with "unknown" state:
1204+ // the maximum possible divisibility, contiguity, and constancy.
1205+ knownDivisibility = DimVectorT (rank, kMaxDivisor );
1206+ knownConstancy = DimVectorT (rank, kMaxDivisor );
1207+ knownContiguity = DimVectorT (rank, kMaxDivisor );
1208+ }
12011209 // Other operations are conservatively initialized with the lowest possible
12021210 // divisibility, contiguity, and constancy unless they have specified.
12031211 AxisInfo::initDimVectorFromHint (op->getDiscardableAttr (" tt.divisibility" ),
@@ -1350,10 +1358,6 @@ void ModuleAxisInfoAnalysis::initialize(FunctionOpInterface funcOp,
13501358 auto *axisInfoMap = getFuncData (funcOp);
13511359 auto updateAxisInfoMap = [&](Value value) {
13521360 auto axisInfo = analysis->getLatticeElement (value)->getValue ();
1353- // If we could not determine the AxisInfo for this value, assume the
1354- // pessimistic state.
1355- if (axisInfo.getRank () == 0 )
1356- axisInfo = AxisInfo::getPessimisticValueState (value);
13571361 AxisInfo curAxisInfo;
13581362 if (axisInfoMap->count (value)) {
13591363 curAxisInfo = AxisInfo::join (axisInfo, axisInfoMap->lookup (value));
0 commit comments