@@ -1165,10 +1165,63 @@ SetVector<Operation *> multiRootGetSlice(Operation *op,
11651165 return multiRootTopologicalSort (slice);
11661166}
11671167
1168+ namespace {
1169+ // Copied from TestDeadCodeAnalysis.cpp, because some dead code analysis
1170+ // interacts with constant propagation, but SparseConstantPropagation
1171+ // doesn't seem to be sufficient.
1172+ class ConstantAnalysis : public DataFlowAnalysis {
1173+ public:
1174+ using DataFlowAnalysis::DataFlowAnalysis;
1175+
1176+ LogicalResult initialize (Operation *top) override {
1177+ WalkResult result = top->walk ([&](Operation *op) {
1178+ ProgramPoint programPoint (op);
1179+ if (failed (visit (&programPoint)))
1180+ return WalkResult::interrupt ();
1181+ return WalkResult::advance ();
1182+ });
1183+ return success (!result.wasInterrupted ());
1184+ }
1185+
1186+ LogicalResult visit (ProgramPoint *point) override {
1187+ Operation *op = point->getOperation ();
1188+ Attribute value;
1189+ if (matchPattern (op, m_Constant (&value))) {
1190+ auto *constant = getOrCreate<dataflow::Lattice<dataflow::ConstantValue>>(
1191+ op->getResult (0 ));
1192+ propagateIfChanged (constant, constant->join (dataflow::ConstantValue (
1193+ value, op->getDialect ())));
1194+ return success ();
1195+ }
1196+ // Dead code analysis requires every operands has initialized ConstantValue
1197+ // state before it is visited.
1198+ // https://github.com/llvm/llvm-project/blob/2ec1aba2b69faa1de5f71832a48e25aa3b5d5314/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp#L322
1199+ // That's why we need to set all operands to unknown constants.
1200+ setAllToUnknownConstants (op->getResults ());
1201+ for (Region ®ion : op->getRegions ()) {
1202+ for (Block &block : region.getBlocks ())
1203+ setAllToUnknownConstants (block.getArguments ());
1204+ }
1205+ return success ();
1206+ }
1207+
1208+ private:
1209+ // / Set all given values as not constants.
1210+ void setAllToUnknownConstants (ValueRange values) {
1211+ dataflow::ConstantValue unknownConstant (nullptr , nullptr );
1212+ for (Value value : values) {
1213+ auto *constant =
1214+ getOrCreate<dataflow::Lattice<dataflow::ConstantValue>>(value);
1215+ propagateIfChanged (constant, constant->join (unknownConstant));
1216+ }
1217+ }
1218+ };
1219+ } // namespace
1220+
11681221std::unique_ptr<DataFlowSolver> createDataFlowSolver () {
11691222 auto solver = std::make_unique<DataFlowSolver>();
11701223 solver->load <dataflow::DeadCodeAnalysis>();
1171- solver->load <dataflow::SparseConstantPropagation >();
1224+ solver->load <ConstantAnalysis >();
11721225 return solver;
11731226}
11741227
0 commit comments