Skip to content

Commit a385803

Browse files
authored
Merge pull request #11162 from tensor-tang/infer/api
Enable the inference API with multiple threads
2 parents 6ada5f4 + 3659401 commit a385803

File tree

4 files changed

+38
-21
lines changed

4 files changed

+38
-21
lines changed

paddle/contrib/inference/demo/simple_on_word2vec.cc

Lines changed: 3 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -65,7 +65,10 @@ void Main(bool use_gpu) {
6565
}
6666

6767
TEST(demo, word2vec_cpu) { Main(false /*use_gpu*/); }
68+
69+
#ifdef PADDLE_WITH_CUDA
6870
TEST(demo, word2vec_gpu) { Main(true /*use_gpu*/); }
71+
#endif
6972

7073
} // namespace demo
7174
} // namespace paddle

paddle/contrib/inference/paddle_inference_api.h

Lines changed: 2 additions & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -63,6 +63,7 @@ class PaddlePredictor {
6363
struct Config;
6464
PaddlePredictor() = default;
6565
PaddlePredictor(const PaddlePredictor&) = delete;
66+
PaddlePredictor& operator=(const PaddlePredictor&) = delete;
6667

6768
// Predict an record.
6869
// The caller should be responsible for allocating and releasing the memory of
@@ -76,7 +77,7 @@ class PaddlePredictor {
7677
virtual std::unique_ptr<PaddlePredictor> Clone() = 0;
7778

7879
// Destroy the Predictor.
79-
virtual ~PaddlePredictor() {}
80+
virtual ~PaddlePredictor() = default;
8081

8182
// The common configs for all the predictors.
8283
struct Config {

paddle/contrib/inference/paddle_inference_api_impl.cc

Lines changed: 27 additions & 17 deletions
Original file line number · Diff line number · Diff line change
@@ -54,17 +54,24 @@ std::string num2str(T a) {
5454
}
5555
} // namespace
5656

57-
bool NativePaddlePredictor::Init() {
57+
bool NativePaddlePredictor::Init(
58+
std::shared_ptr<framework::Scope> parent_scope) {
5859
VLOG(3) << "Predictor::init()";
5960

6061
if (config_.use_gpu) {
6162
place_ = paddle::platform::CUDAPlace(config_.device);
6263
} else {
6364
place_ = paddle::platform::CPUPlace();
6465
}
65-
paddle::framework::InitDevices(false);
66+
if (parent_scope) {
67+
scope_ = parent_scope;
68+
sub_scope_ = &(parent_scope->NewScope());
69+
} else {
70+
paddle::framework::InitDevices(false);
71+
scope_.reset(new paddle::framework::Scope());
72+
}
73+
6674
executor_.reset(new paddle::framework::Executor(place_));
67-
scope_.reset(new paddle::framework::Scope());
6875

6976
// Initialize the inference program
7077
if (!config_.model_dir.empty()) {
@@ -83,20 +90,22 @@ bool NativePaddlePredictor::Init() {
8390
return false;
8491
}
8592
ctx_ = executor_->Prepare(*inference_program_, 0);
86-
87-
// Create temporary variables first, so that the first batch do not need to
88-
// create variables in the runtime. This is the logics of the old inference
89-
// API.
90-
// TODO(Superjomn) this should be modified when `Clone` is valid for
91-
// multi-thread application.
92-
executor_->CreateVariables(*inference_program_, scope_.get(), 0);
93+
executor_->CreateVariables(
94+
*inference_program_, sub_scope_ ? sub_scope_ : scope_.get(), 0);
9395

9496
// Get the feed_target_names and fetch_target_names
9597
feed_target_names_ = inference_program_->GetFeedTargetNames();
9698
fetch_target_names_ = inference_program_->GetFetchTargetNames();
9799
return true;
98100
}
99101

102+
NativePaddlePredictor::~NativePaddlePredictor() {
103+
if (sub_scope_) {
104+
PADDLE_ENFORCE_NOT_NULL(scope_, "Should have parent scope!");
105+
scope_->DeleteScope(sub_scope_);
106+
}
107+
};
108+
100109
bool NativePaddlePredictor::Run(const std::vector<PaddleTensor> &inputs,
101110
std::vector<PaddleTensor> *output_data) {
102111
VLOG(3) << "Predictor::predict";
@@ -121,11 +130,12 @@ bool NativePaddlePredictor::Run(const std::vector<PaddleTensor> &inputs,
121130
}
122131
// Run the inference program
123132
// if share variables, we need not create variables
124-
executor_->RunPreparedContext(ctx_.get(),
125-
scope_.get(),
126-
&feed_targets,
127-
&fetch_targets,
128-
false /* don't create variable eatch time */);
133+
executor_->RunPreparedContext(
134+
ctx_.get(),
135+
sub_scope_ != nullptr ? sub_scope_ : scope_.get(),
136+
&feed_targets,
137+
&fetch_targets,
138+
false /* don't create variable eatch time */);
129139
if (!GetFetch(fetchs, output_data)) {
130140
LOG(ERROR) << "fail to get fetchs";
131141
return false;
@@ -138,7 +148,7 @@ std::unique_ptr<PaddlePredictor> NativePaddlePredictor::Clone() {
138148
VLOG(3) << "Predictor::clone";
139149
std::unique_ptr<PaddlePredictor> cls(new NativePaddlePredictor(config_));
140150

141-
if (!dynamic_cast<NativePaddlePredictor *>(cls.get())->Init()) {
151+
if (!dynamic_cast<NativePaddlePredictor *>(cls.get())->Init(scope_)) {
142152
LOG(ERROR) << "fail to call Init";
143153
return nullptr;
144154
}
@@ -266,7 +276,7 @@ CreatePaddlePredictor<NativeConfig, PaddleEngineKind::kNative>(
266276
}
267277

268278
std::unique_ptr<PaddlePredictor> predictor(new NativePaddlePredictor(config));
269-
if (!dynamic_cast<NativePaddlePredictor *>(predictor.get())->Init()) {
279+
if (!dynamic_cast<NativePaddlePredictor *>(predictor.get())->Init(nullptr)) {
270280
return nullptr;
271281
}
272282
return std::move(predictor);

paddle/contrib/inference/paddle_inference_api_impl.h

Lines changed: 6 additions & 3 deletions
Original file line number · Diff line number · Diff line change
@@ -34,14 +34,15 @@ class NativePaddlePredictor : public PaddlePredictor {
3434
explicit NativePaddlePredictor(const NativeConfig &config)
3535
: config_(config) {}
3636

37-
bool Init();
37+
// will only create sub scope if have global scope
38+
bool Init(std::shared_ptr<framework::Scope> parent_scope);
3839

3940
bool Run(const std::vector<PaddleTensor> &inputs,
4041
std::vector<PaddleTensor> *output_data) override;
4142

4243
std::unique_ptr<PaddlePredictor> Clone() override;
4344

44-
~NativePaddlePredictor() override{};
45+
~NativePaddlePredictor() override;
4546

4647
private:
4748
bool SetFeed(const std::vector<PaddleTensor> &input_datas,
@@ -52,11 +53,13 @@ class NativePaddlePredictor : public PaddlePredictor {
5253
NativeConfig config_;
5354
platform::Place place_;
5455
std::unique_ptr<framework::Executor> executor_;
55-
std::unique_ptr<framework::Scope> scope_;
56+
std::shared_ptr<framework::Scope> scope_;
5657
std::unique_ptr<framework::ExecutorPrepareContext> ctx_;
5758
std::unique_ptr<framework::ProgramDesc> inference_program_;
5859
std::vector<std::string> feed_target_names_;
5960
std::vector<std::string> fetch_target_names_;
61+
// Do not use unique_ptr, use parent scope to delete
62+
framework::Scope *sub_scope_{nullptr};
6063
};
6164

6265
} // namespace paddle

0 commit comments

Comments (0)