@@ -195,9 +195,9 @@ class AxisInfoAnalysis : public dataflow::SparseForwardDataFlowAnalysis<
195195 dataflow::Lattice<AxisInfo>>::getLatticeElement;
196196 using FuncAxisInfoMapT = DenseMap<FunctionOpInterface, AxisInfo>;
197197
198- void visitOperation (Operation *op,
199- ArrayRef<const dataflow::Lattice<AxisInfo> *> operands,
200- ArrayRef<dataflow::Lattice<AxisInfo> *> results) override ;
198+ LogicalResult visitOperation (
199+ Operation *op, ArrayRef<const dataflow::Lattice<AxisInfo> *> operands,
200+ ArrayRef<dataflow::Lattice<AxisInfo> *> results) override ;
201201 void
202202 visitForOpInductionVar (scf::ForOp op,
203203 ArrayRef<dataflow::Lattice<AxisInfo> *> argLattices);
@@ -1039,7 +1039,7 @@ AxisInfoAnalysis::AxisInfoAnalysis(DataFlowSolver &solver)
10391039 visitors.append <LoadOpAxisInfoVisitor>();
10401040}
10411041
1042- void AxisInfoAnalysis::visitOperation (
1042+ LogicalResult AxisInfoAnalysis::visitOperation (
10431043 Operation *op, ArrayRef<const dataflow::Lattice<AxisInfo> *> operands,
10441044 ArrayRef<dataflow::Lattice<AxisInfo> *> results) {
10451045 // TODO: For sure not the right way to do this
@@ -1048,8 +1048,10 @@ void AxisInfoAnalysis::visitOperation(
10481048 if (op->getValue ().getRank () == 0 )
10491049 setToEntryState ((dataflow::Lattice<AxisInfo> *)op);
10501050 AxisInfo curr = visitors.apply (op, operands);
1051- if (curr.getRank () == 0 )
1052- return setAllToEntryStates (results);
1051+ if (curr.getRank () == 0 ) {
1052+ setAllToEntryStates (results);
1053+ return mlir::success ();
1054+ }
10531055 // override with hint
10541056 auto newContiguity = curr.getContiguity ();
10551057 auto newDivisibility = curr.getDivisibility ();
@@ -1071,6 +1073,7 @@ void AxisInfoAnalysis::visitOperation(
10711073 // join all lattice elements
10721074 for (auto *result : results)
10731075 propagateIfChanged (result, result->join (curr));
1076+ return mlir::success ();
10741077}
10751078
10761079void AxisInfoAnalysis::visitForOpInductionVar (
0 commit comments