
Commit 79bda2c

enable infer api with multi-threads
Parent: 418c41d


3 files changed (+33, -20 lines)


paddle/contrib/inference/paddle_inference_api.h

Lines changed: 2 additions & 1 deletion
@@ -63,6 +63,7 @@ class PaddlePredictor {
   struct Config;
   PaddlePredictor() = default;
   PaddlePredictor(const PaddlePredictor&) = delete;
+  PaddlePredictor& operator=(const PaddlePredictor&) = delete;

   // Predict a record.
   // The caller should be responsible for allocating and releasing the memory of
@@ -76,7 +77,7 @@ class PaddlePredictor {
   virtual std::unique_ptr<PaddlePredictor> Clone() = 0;

   // Destroy the Predictor.
-  virtual ~PaddlePredictor() {}
+  virtual ~PaddlePredictor() = default;

   // The common configs for all the predictors.
   struct Config {
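With this hunk, both the copy constructor and the copy assignment operator are deleted, so Clone() becomes the only way to duplicate a predictor. A minimal sketch of the pattern (hypothetical Predictor/EchoPredictor names, not the Paddle headers):

#include <memory>

// With copy construction and copy assignment both deleted, every duplication
// goes through the virtual Clone() hook, which can set up per-instance state.
class Predictor {
 public:
  Predictor() = default;
  Predictor(const Predictor&) = delete;
  Predictor& operator=(const Predictor&) = delete;
  virtual ~Predictor() = default;

  virtual std::unique_ptr<Predictor> Clone() = 0;
};

class EchoPredictor : public Predictor {
 public:
  std::unique_ptr<Predictor> Clone() override {
    // Return a fresh, independently initialized instance rather than a
    // member-wise copy.
    return std::unique_ptr<Predictor>(new EchoPredictor());
  }
};

int main() {
  EchoPredictor a;
  // EchoPredictor b(a);   // ill-formed: copy constructor is deleted
  // a = EchoPredictor();  // ill-formed: copy assignment is deleted
  auto b = a.Clone();      // the sanctioned duplication path
  (void)b;
  return 0;
}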

paddle/contrib/inference/paddle_inference_api_impl.cc

Lines changed: 25 additions & 16 deletions
@@ -54,17 +54,23 @@ std::string num2str(T a) {
 }
 }  // namespace

-bool NativePaddlePredictor::Init() {
+bool NativePaddlePredictor::Init(std::shared_ptr<framework::Scope> scope) {
   VLOG(3) << "Predictor::init()";

   if (config_.use_gpu) {
     place_ = paddle::platform::CUDAPlace(config_.device);
   } else {
     place_ = paddle::platform::CPUPlace();
   }
-  paddle::framework::InitDevices(false);
+  if (scope) {
+    scope_ = scope;
+    sub_scope_ = &(scope->NewScope());
+  } else {
+    paddle::framework::InitDevices(false);
+    scope_.reset(new paddle::framework::Scope());
+  }
+
   executor_.reset(new paddle::framework::Executor(place_));
-  scope_.reset(new paddle::framework::Scope());

   // Initialize the inference program
   if (!config_.model_dir.empty()) {
@@ -83,20 +89,22 @@ bool NativePaddlePredictor::Init() {
     return false;
   }
   ctx_ = executor_->Prepare(*inference_program_, 0);
-
-  // Create temporary variables first, so that the first batch do not need to
-  // create variables in the runtime. This is the logics of the old inference
-  // API.
-  // TODO(Superjomn) this should be modified when `Clone` is valid for
-  // multi-thread application.
-  executor_->CreateVariables(*inference_program_, scope_.get(), 0);
+  executor_->CreateVariables(
+      *inference_program_, sub_scope_ ? sub_scope_ : scope_.get(), 0);

   // Get the feed_target_names and fetch_target_names
   feed_target_names_ = inference_program_->GetFeedTargetNames();
   fetch_target_names_ = inference_program_->GetFetchTargetNames();
   return true;
 }

+NativePaddlePredictor::~NativePaddlePredictor() {
+  if (sub_scope_) {
+    PADDLE_ENFORCE_NOT_NULL(scope_, "Should have parent scope!");
+    scope_->DeleteScope(sub_scope_);
+  }
+};
+
 bool NativePaddlePredictor::Run(const std::vector<PaddleTensor> &inputs,
                                 std::vector<PaddleTensor> *output_data) {
   VLOG(3) << "Predictor::predict";
@@ -121,11 +129,12 @@ bool NativePaddlePredictor::Run(const std::vector<PaddleTensor> &inputs,
   }
   // Run the inference program
   // if variables are shared, we need not create them
-  executor_->RunPreparedContext(ctx_.get(),
-                                scope_.get(),
-                                &feed_targets,
-                                &fetch_targets,
-                                false /* don't create variable eatch time */);
+  executor_->RunPreparedContext(
+      ctx_.get(),
+      sub_scope_ != nullptr ? sub_scope_ : scope_.get(),
+      &feed_targets,
+      &fetch_targets,
+      false /* don't create variable each time */);
   if (!GetFetch(fetchs, output_data)) {
     LOG(ERROR) << "fail to get fetchs";
     return false;
@@ -138,7 +147,7 @@ std::unique_ptr<PaddlePredictor> NativePaddlePredictor::Clone() {
   VLOG(3) << "Predictor::clone";
   std::unique_ptr<PaddlePredictor> cls(new NativePaddlePredictor(config_));

-  if (!dynamic_cast<NativePaddlePredictor *>(cls.get())->Init()) {
+  if (!dynamic_cast<NativePaddlePredictor *>(cls.get())->Init(scope_)) {
     LOG(ERROR) << "fail to call Init";
     return nullptr;
   }
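The new destructor relies on the scope tree for cleanup: the parent scope owns its sub-scopes, so the predictor holds only a raw pointer and asks the parent to delete it. A minimal sketch of that ownership scheme, with a hypothetical Scope class standing in for paddle::framework::Scope:

#include <memory>
#include <vector>

// The parent owns its sub-scopes outright; holders of a sub-scope keep only
// a raw pointer and must ask the parent to delete it.
class Scope {
 public:
  Scope& NewScope() {
    kids_.emplace_back(new Scope());
    return *kids_.back();
  }
  void DeleteScope(Scope* kid) {
    for (auto it = kids_.begin(); it != kids_.end(); ++it) {
      if (it->get() == kid) {
        kids_.erase(it);  // the unique_ptr frees the child here
        return;
      }
    }
  }

 private:
  std::vector<std::unique_ptr<Scope>> kids_;
};

int main() {
  auto parent = std::make_shared<Scope>();  // plays the role of scope_
  Scope* sub = &parent->NewScope();         // plays the role of sub_scope_
  // ... per-predictor variables would live in *sub ...
  parent->DeleteScope(sub);                 // mirrors ~NativePaddlePredictor()
  return 0;
}

This is why sub_scope_ below is a raw pointer rather than a unique_ptr: double deletion would occur if both the predictor and the parent scope owned it.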

paddle/contrib/inference/paddle_inference_api_impl.h

Lines changed: 6 additions & 3 deletions
@@ -34,14 +34,15 @@ class NativePaddlePredictor : public PaddlePredictor {
   explicit NativePaddlePredictor(const NativeConfig &config)
       : config_(config) {}

-  bool Init();
+  // Will only create a sub-scope if a global scope is given.
+  bool Init(std::shared_ptr<framework::Scope> scope = nullptr);

   bool Run(const std::vector<PaddleTensor> &inputs,
            std::vector<PaddleTensor> *output_data) override;

   std::unique_ptr<PaddlePredictor> Clone() override;

-  ~NativePaddlePredictor() override{};
+  ~NativePaddlePredictor() override;

  private:
   bool SetFeed(const std::vector<PaddleTensor> &input_datas,
@@ -52,11 +53,13 @@ class NativePaddlePredictor : public PaddlePredictor {
   NativeConfig config_;
   platform::Place place_;
   std::unique_ptr<framework::Executor> executor_;
-  std::unique_ptr<framework::Scope> scope_;
+  std::shared_ptr<framework::Scope> scope_;
   std::unique_ptr<framework::ExecutorPrepareContext> ctx_;
   std::unique_ptr<framework::ProgramDesc> inference_program_;
   std::vector<std::string> feed_target_names_;
   std::vector<std::string> fetch_target_names_;
+  // Do not use unique_ptr; the parent scope owns and deletes this sub-scope.
+  framework::Scope *sub_scope_{nullptr};
 };

 }  // namespace paddle
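Taken together, the three files let one predictor serve several threads: Clone() forwards its own scope_ into Init(), each clone creates a private sub-scope inside that shared parent, and the shared_ptr keeps the parent alive until the last clone is gone. A hedged usage sketch, assuming the contrib API's CreatePaddlePredictor factory; the model path is a placeholder and real feed tensors must be filled in:

#include <memory>
#include <thread>
#include <vector>

#include "paddle/contrib/inference/paddle_inference_api.h"

int main() {
  paddle::NativeConfig config;
  config.model_dir = "/path/to/model";  // placeholder path
  config.use_gpu = false;

  // The first predictor loads the program and parameters into its own scope.
  auto main_predictor =
      paddle::CreatePaddlePredictor<paddle::NativeConfig>(config);

  std::vector<std::thread> workers;
  for (int i = 0; i < 4; ++i) {
    // Clone() calls Init(scope_): the clone shares the parent scope but
    // creates and runs in its own sub-scope, so threads do not share
    // mutable variables.
    std::shared_ptr<paddle::PaddlePredictor> predictor =
        main_predictor->Clone();
    workers.emplace_back([predictor] {
      std::vector<paddle::PaddleTensor> inputs;  // fill with real feeds
      std::vector<paddle::PaddleTensor> outputs;
      predictor->Run(inputs, &outputs);
    });
  }
  for (auto &t : workers) t.join();
  return 0;
}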
