
Commit 526790e

infer get program (#15511)
1 parent 3c224e7 commit 526790e

4 files changed, 16 insertions(+), 0 deletions(-)

paddle/fluid/inference/api/analysis_predictor.cc

Lines changed: 4 additions & 0 deletions
@@ -726,6 +726,10 @@ bool AnalysisPredictor::need_collect_var_shapes_for_memory_optim() {
   return need;
 }
 
+std::string AnalysisPredictor::GetSeriazlizedProgram() const {
+  return inference_program_->Proto()->SerializeAsString();
+}
+
 template <>
 std::unique_ptr<PaddlePredictor> CreatePaddlePredictor<contrib::AnalysisConfig>(
     const contrib::AnalysisConfig &config) {
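
The new implementation simply serializes the predictor's in-memory inference ProgramDesc proto. As a quick illustration (not part of this commit), the returned bytes can be parsed back into the protobuf message; this is only a sketch, assuming the generated header paddle/fluid/framework/framework.pb.h and the paddle::framework::proto::ProgramDesc message:

// Sketch only: round-trip the serialized program back into its protobuf form.
#include <cassert>
#include <string>

#include "paddle/fluid/framework/framework.pb.h"  // assumed generated proto header
#include "paddle/fluid/inference/api/paddle_api.h"

void InspectProgram(const paddle::PaddlePredictor &predictor) {
  // Bytes produced by GetSeriazlizedProgram().
  std::string buf = predictor.GetSeriazlizedProgram();

  // Parse them back into the ProgramDesc protobuf message.
  paddle::framework::proto::ProgramDesc desc;
  bool ok = desc.ParseFromString(buf);
  assert(ok);

  // For example, count the ops of the inference program across all blocks.
  int num_ops = 0;
  for (const auto &block : desc.blocks()) {
    num_ops += block.ops_size();
  }
  (void)num_ops;
}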

paddle/fluid/inference/api/analysis_predictor.h

Lines changed: 2 additions & 0 deletions
@@ -75,6 +75,8 @@ class AnalysisPredictor : public PaddlePredictor {
 
   void SetMkldnnThreadID(int tid);
 
+  std::string GetSeriazlizedProgram() const override;
+
  protected:
   // For memory optimization.
   bool need_collect_var_shapes_for_memory_optim();

paddle/fluid/inference/api/analysis_predictor_tester.cc

Lines changed: 2 additions & 0 deletions
@@ -215,6 +215,8 @@ TEST(AnalysisPredictor, memory_optim) {
   {
     // The first predictor help to cache the memory optimize strategy.
     auto predictor = CreatePaddlePredictor<AnalysisConfig>(config);
+    LOG(INFO) << "serialized program: " << predictor->GetSeriazlizedProgram();
+    ASSERT_FALSE(predictor->GetSeriazlizedProgram().empty());
 
     // Run several times to check the parameters are not reused by mistake.
     for (int i = 0; i < 5; i++) {

paddle/fluid/inference/api/paddle_api.h

Lines changed: 8 additions & 0 deletions
@@ -215,6 +215,14 @@ class PaddlePredictor {
    */
   virtual ~PaddlePredictor() = default;
 
+  /** \brief Get the serialized model program that executes in inference phase.
+   * Its data type is ProgramDesc, which is a protobuf message.
+   */
+  virtual std::string GetSeriazlizedProgram() const {
+    assert(false);  // Force raise error.
+    return "NotImplemented";
+  };
+
   /** The common configs for all the predictors.
    */
   struct Config {
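
The base-class default intentionally fails: it asserts in debug builds and returns a placeholder string otherwise, so callers should only rely on the method with predictors that override it, such as AnalysisPredictor above. As a minimal usage sketch (not part of this commit; the helper function and output file name are placeholders), a client could dump the program a predictor executes for offline inspection:

// Sketch only: save the serialized inference program to a file.
#include <fstream>
#include <string>

#include "paddle/fluid/inference/api/paddle_api.h"

// "inference_program.pb" is a placeholder output path.
void DumpProgram(const paddle::PaddlePredictor &predictor) {
  std::string program = predictor.GetSeriazlizedProgram();
  std::ofstream out("inference_program.pb", std::ios::binary);
  out.write(program.data(), static_cast<std::streamsize>(program.size()));
}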
