Skip to content

Commit 30defc0

Browse files
authored
add value_replaced_hook in inference pass for cinn (#69888)
1 parent a2d9152 commit 30defc0

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

paddle/fluid/inference/api/analysis_predictor.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@
102102
#include "paddle/cinn/hlir/dialect/operator/transforms/add_cinn_pass.h"
103103
#include "paddle/cinn/hlir/dialect/operator/transforms/check_infer_symbolic_util.h"
104104
#include "paddle/pir/include/dialect/shape/ir/shape_dialect.h"
105+
#include "paddle/pir/include/dialect/shape/utils/shape_analysis.h"
105106
#endif
106107

107108
#include "paddle/common/flags.h"
@@ -843,6 +844,11 @@ void AnalysisPredictor::OptimizeInferencePirProgram() {
843844
std::make_unique<pir::PassManager::IRPrinterOption>(
844845
ir_printing_conditions, ir_printing_conditions));
845846
}
847+
auto &shape_analysis =
848+
pir::ShapeAnalysisManager::Instance().Get(pir_program_.get());
849+
pass_manager->SetValueReplacedHook([&](pir::Value from, pir::Value to) {
850+
shape_analysis.ShareShapeOrData(from, to);
851+
});
846852
return pass_manager;
847853
};
848854

0 commit comments

Comments
 (0)