@@ -1079,11 +1079,10 @@ AxisInfoAnalysis::AxisInfoAnalysis(DataFlowSolver &solver,
10791079LogicalResult AxisInfoAnalysis::visitOperation (
10801080 Operation *op, ArrayRef<const dataflow::Lattice<AxisInfo> *> operands,
10811081 ArrayRef<dataflow::Lattice<AxisInfo> *> results) {
1082- // TODO: For sure not the right way to do this
1083- // but why is scf.if not initialized otherwise?
1082+ // If any operands are not yet ready, skip this operation for now.
10841083 for (auto op : operands)
10851084 if (op->getValue ().getRank () == 0 )
1086- setToEntryState ((dataflow::Lattice<AxisInfo> *)op );
1085+ return success ( );
10871086 AxisInfo curr = visitors.apply (op, operands);
10881087 if (curr.getRank () == 0 ) {
10891088 setAllToEntryStates (results);
@@ -1112,9 +1111,11 @@ void AxisInfoAnalysis::visitForOpInductionVar(
11121111 ProgramPoint *programPoint = getProgramPointAfter (op);
11131112 auto *lbLattice = getLatticeElementFor (programPoint, op.getLowerBound ());
11141113 auto *stepLattice = getLatticeElementFor (programPoint, op.getStep ());
1115- for (auto op_iter : {lbLattice, stepLattice})
1116- if (op_iter->getValue ().getRank () == 0 )
1117- setToEntryState ((dataflow::Lattice<AxisInfo> *)op_iter);
1114+ // If lb or step is not yet ready, skip this operation for now.
1115+ if (lbLattice->getValue ().getRank () == 0 ||
1116+ stepLattice->getValue ().getRank () == 0 ) {
1117+ return ;
1118+ }
11181119
11191120 AxisInfo::DimVectorT knownContiguity (1 , 1 );
11201121 AxisInfo::DimVectorT knownDivisibility (1 , 1 );
@@ -1188,24 +1189,15 @@ void AxisInfo::initDimVectorFromHint(Attribute attr, DimVectorT *vec) {
11881189 initPessimisticStateFromFunc (blockArg.getArgNumber (), fun,
11891190 &knownContiguity, &knownDivisibility,
11901191 &knownConstancy);
1191- } else if (isa<RegionBranchOpInterface, gpu::WarpSpecializePartitionsOp>(
1192- op)) {
1193- // scf::ForOp, scf::IfOp, scf::WhileOp, gpu::WarpSpecializePartitionsOp
1194- // Control flow operations are initialized with "unknown" state:
1195- // the maximum possible divisibility, contiguity, and constancy.
1192+ } else if (isa<gpu::WarpSpecializePartitionsOp>(op)) {
1193+ // Initialize the arguments to gpu::WarpSpecializePartitionsOp with
1194+ // "unknown" state: the maximum possible divisibility, contiguity, and
1195+ // constancy.
11961196 knownDivisibility = DimVectorT (rank, kMaxDivisor );
11971197 knownConstancy = DimVectorT (rank, kMaxDivisor );
11981198 knownContiguity = DimVectorT (rank, kMaxDivisor );
11991199 }
12001200 } else if (Operation *op = value.getDefiningOp ()) {
1201- if (isa<RegionBranchOpInterface>(op)) {
1202- // scf::ForOp, scf::IfOp, scf::WhileOp
1203- // Control flow operations are initialized with "unknown" state:
1204- // the maximum possible divisibility, contiguity, and constancy.
1205- knownDivisibility = DimVectorT (rank, kMaxDivisor );
1206- knownConstancy = DimVectorT (rank, kMaxDivisor );
1207- knownContiguity = DimVectorT (rank, kMaxDivisor );
1208- }
12091201 // Other operations are conservatively initialized with the lowest possible
12101202 // divisibility, contiguity, and constancy unless they have specified.
12111203 AxisInfo::initDimVectorFromHint (op->getDiscardableAttr (" tt.divisibility" ),
@@ -1358,6 +1350,10 @@ void ModuleAxisInfoAnalysis::initialize(FunctionOpInterface funcOp,
13581350 auto *axisInfoMap = getFuncData (funcOp);
13591351 auto updateAxisInfoMap = [&](Value value) {
13601352 auto axisInfo = analysis->getLatticeElement (value)->getValue ();
1353+ // If we could not determine the AxisInfo for this value, assume the
1354+ // pessimistic state.
1355+ if (axisInfo.getRank () == 0 )
1356+ axisInfo = AxisInfo::getPessimisticValueState (value);
13611357 auto &valInfo = (*axisInfoMap)[value];
13621358 valInfo = AxisInfo::join (axisInfo, valInfo);
13631359 };
0 commit comments