Skip to content

Commit c13efe0

Browse files
committed
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into add_tensorrt_elementwise_add
2 parents b241a47 + baff71d commit c13efe0

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

48 files changed

+1010
-365
lines changed
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
17+
#include "paddle/fluid/platform/enforce.h"
18+
19+
namespace paddle {
20+
namespace framework {
21+
namespace details {
22+
23+
class ExceptionHolder {
24+
public:
25+
void Catch(const platform::EnforceNotMet& exp) {
26+
std::lock_guard<std::mutex> lock(mu_);
27+
exception_.reset(new platform::EnforceNotMet(exp));
28+
type_ = kEnforceNotMet;
29+
}
30+
31+
void Catch(const platform::EOFException& exp) {
32+
std::lock_guard<std::mutex> lock(mu_);
33+
// EOFException will not cover up existing EnforceNotMet.
34+
if (exception_.get() == nullptr) {
35+
exception_.reset(new platform::EOFException(exp));
36+
type_ = kEOF;
37+
}
38+
}
39+
40+
bool ExceptionCatched() const {
41+
std::lock_guard<std::mutex> lock(mu_);
42+
return exception_.get() != nullptr;
43+
}
44+
45+
void Throw() {
46+
std::lock_guard<std::mutex> lock(mu_);
47+
switch (type_) {
48+
case kNone:
49+
break;
50+
case kEnforceNotMet: {
51+
auto e = *static_cast<platform::EnforceNotMet*>(exception_.get());
52+
throw e;
53+
break;
54+
}
55+
case kEOF: {
56+
auto e = *static_cast<platform::EOFException*>(exception_.get());
57+
throw e;
58+
break;
59+
}
60+
default:
61+
LOG(FATAL) << "Unknown exception.";
62+
}
63+
exception_.reset();
64+
type_ = kNone;
65+
}
66+
67+
void Clear() {
68+
std::lock_guard<std::mutex> lock(mu_);
69+
exception_.reset();
70+
type_ = kNone;
71+
}
72+
73+
private:
74+
enum ExceptionType { kNone, kEnforceNotMet, kEOF };
75+
ExceptionType type_{kNone};
76+
77+
std::unique_ptr<std::exception> exception_;
78+
mutable std::mutex mu_;
79+
};
80+
81+
} // namespace details
82+
} // namespace framework
83+
} // namespace paddle

paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,9 @@ class ScopeBufferedSSAGraphExecutor : public SSAGraphExecutor {
4141
std::vector<VariableInfo> var_infos, std::vector<platform::Place> places,
4242
std::unique_ptr<SSAGraphExecutor>&& underlying_executor);
4343

44-
const ir::Graph& Graph() const { return underlying_executor_->Graph(); }
44+
const ir::Graph& Graph() const override {
45+
return underlying_executor_->Graph();
46+
}
4547

4648
FeedFetchList Run(const std::vector<std::string>& fetch_tensors) override;
4749

paddle/fluid/framework/details/threaded_ssa_graph_executor.cc

Lines changed: 5 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
8383

8484
// Clean run context
8585
run_op_futures_.clear();
86-
exception_.reset();
86+
exception_holder_.Clear();
8787

8888
// Step 3. Execution
8989
while (!pending_vars.empty()) {
@@ -103,23 +103,11 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
103103
auto cur_ready_vars = ready_vars.PopAll(1, &timeout);
104104

105105
if (timeout) {
106-
std::unique_lock<std::mutex> l(exception_mu_);
107-
if (exception_) {
108-
l.unlock();
106+
if (exception_holder_.ExceptionCatched()) {
109107
for (auto &run_op_future : run_op_futures_) {
110108
run_op_future.wait();
111109
}
112-
l.lock();
113-
std::exception *exp = exception_.get();
114-
if (dynamic_cast<platform::EOFException *>(exp)) {
115-
auto e = *static_cast<platform::EOFException *>(exp);
116-
throw e;
117-
} else if (dynamic_cast<platform::EnforceNotMet *>(exp)) {
118-
auto e = *static_cast<platform::EnforceNotMet *>(exp);
119-
throw e;
120-
} else {
121-
LOG(FATAL) << "Unknown exception.";
122-
}
110+
exception_holder_.Throw();
123111
} else {
124112
continue;
125113
}
@@ -229,14 +217,9 @@ void ThreadedSSAGraphExecutor::RunOp(
229217
ready_var_q->Extend(op->Outputs());
230218
VLOG(10) << op << " " << op->Name() << "Signal posted";
231219
} catch (platform::EOFException ex) {
232-
std::lock_guard<std::mutex> l(exception_mu_);
233-
// EOFException will not cover up existing EnforceNotMet.
234-
if (exception_.get() == nullptr) {
235-
exception_.reset(new platform::EOFException(ex));
236-
}
220+
exception_holder_.Catch(ex);
237221
} catch (platform::EnforceNotMet ex) {
238-
std::lock_guard<std::mutex> l(exception_mu_);
239-
exception_.reset(new platform::EnforceNotMet(ex));
222+
exception_holder_.Catch(ex);
240223
} catch (...) {
241224
LOG(FATAL) << "Unknown exception catched";
242225
}

paddle/fluid/framework/details/threaded_ssa_graph_executor.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include <functional>
2525
#include "ThreadPool.h" // ThreadPool in thrird party
2626
#include "paddle/fluid/framework/blocking_queue.h"
27+
#include "paddle/fluid/framework/details/exception_holder.h"
2728
#include "paddle/fluid/framework/details/execution_strategy.h"
2829
#include "paddle/fluid/framework/details/fetch_op_handle.h"
2930
#include "paddle/fluid/framework/details/ssa_graph_executor.h"
@@ -42,7 +43,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
4243
const std::vector<platform::Place> &places,
4344
std::unique_ptr<ir::Graph> &&graph);
4445

45-
const ir::Graph &Graph() const { return *graph_; }
46+
const ir::Graph &Graph() const override { return *graph_; }
4647
// Run a SSAGraph by a thread pool
4748
// Use topological sort algorithm
4849
FeedFetchList Run(const std::vector<std::string> &fetch_tensors) override;
@@ -59,8 +60,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
5960
std::vector<Scope *> local_scopes_;
6061
std::vector<platform::Place> places_;
6162
platform::DeviceContextPool fetch_ctxs_;
62-
std::mutex exception_mu_;
63-
std::unique_ptr<std::exception> exception_;
63+
ExceptionHolder exception_holder_;
6464
std::atomic<int> running_ops_;
6565

6666
void InsertPendingOp(std::unordered_map<OpHandleBase *, size_t> *pending_ops,

paddle/fluid/inference/analysis/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,11 @@ cc_library(analysis SRCS pass_manager.cc dot.cc node.cc data_flow_graph.cc graph
66
tensorrt_subgraph_node_mark_pass.cc
77
analyzer.cc
88
helper.cc
9+
model_store_pass.cc
910
DEPS framework_proto proto_desc)
1011
cc_test(test_node SRCS node_tester.cc DEPS analysis)
1112
cc_test(test_dot SRCS dot_tester.cc DEPS analysis)
13+
cc_binary(inference_analyzer SRCS analyzer_main.cc DEPS analysis)
1214

1315
set(PYTHON_TESTS_DIR ${PADDLE_BINARY_DIR}/python/paddle/fluid/tests)
1416

@@ -40,3 +42,4 @@ inference_analysis_test(test_tensorrt_subgraph_pass SRCS tensorrt_subgraph_pass_
4042
inference_analysis_test(test_pass_manager SRCS pass_manager_tester.cc)
4143
inference_analysis_test(test_tensorrt_subgraph_node_mark_pass SRCS tensorrt_subgraph_node_mark_pass_tester.cc)
4244
inference_analysis_test(test_analyzer SRCS analyzer_tester.cc)
45+
inference_analysis_test(test_model_store_pass SRCS model_store_pass_tester.cc)

paddle/fluid/inference/analysis/analyzer.cc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.h"
1818
#include "paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.h"
1919
#include "paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h"
20+
#include "paddle/fluid/inference/analysis/model_store_pass.h"
2021
#include "paddle/fluid/inference/analysis/pass_manager.h"
2122
#include "paddle/fluid/inference/analysis/tensorrt_subgraph_node_mark_pass.h"
2223
#include "paddle/fluid/inference/analysis/tensorrt_subgraph_pass.h"
@@ -29,6 +30,9 @@ DEFINE_bool(inference_analysis_enable_tensorrt_subgraph_engine, false,
2930
DEFINE_string(inference_analysis_graphviz_log_root, "./",
3031
"Graphviz debuger for data flow graphs.");
3132

33+
DEFINE_string(inference_analysis_output_storage_path, "",
34+
"optimized model output path");
35+
3236
namespace inference {
3337
namespace analysis {
3438

@@ -47,6 +51,9 @@ class DfgPassManagerImpl final : public DfgPassManager {
4751
AddPass("tensorrt-subgraph", new TensorRTSubGraphPass(trt_teller));
4852
}
4953
AddPass("data-flow-graph-to-fluid", new DataFlowGraphToFluidPass);
54+
if (!FLAGS_inference_analysis_output_storage_path.empty()) {
55+
AddPass("model-store-pass", new ModelStorePass);
56+
}
5057
}
5158

5259
std::string repr() const override { return "dfg-pass-manager"; }

paddle/fluid/inference/analysis/analyzer.h

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -16,28 +16,23 @@ limitations under the License. */
1616

1717
/*
1818
* This file contains Analyzer, an class that exposed as a library that analyze
19-
* and optimize
20-
* Fluid ProgramDesc for inference. Similar to LLVM, it has multiple flags to
21-
* control whether
22-
* an process is applied on the program.
19+
* and optimize Fluid ProgramDesc for inference. Similar to LLVM, it has
20+
* multiple flags to
21+
* control whether an process is applied on the program.
2322
*
2423
* The processes are called Passes in analysis, the Passes are placed in a
25-
* pipeline, the first
26-
* Pass is the FluidToDataFlowGraphPass which transforms a Fluid ProgramDesc to
27-
* a data flow
28-
* graph, the last Pass is DataFlowGraphToFluidPass which transforms a data flow
29-
* graph to a
30-
* Fluid ProgramDesc. The passes in the middle of the pipeline can be any Passes
31-
* which take a
32-
* node or data flow graph as input.
24+
* pipeline, the first Pass is the FluidToDataFlowGraphPass which transforms a
25+
* Fluid ProgramDesc to
26+
* a data flow graph, the last Pass is DataFlowGraphToFluidPass which transforms
27+
* a data flow graph to a Fluid ProgramDesc. The passes in the middle of the
28+
* pipeline can be any Passes
29+
* which take a node or data flow graph as input.
3330
*
3431
* The Analyzer can be used in two methods, the first is a executable file which
35-
* can be used to
36-
* pre-process the inference model and can be controlled by passing difference
37-
* command flags;
32+
* can be used to pre-process the inference model and can be controlled by
33+
* passing difference command flags;
3834
* the other way is to compose inside the inference API as a runtime pre-process
39-
* phase in the
40-
* inference service.
35+
* phase in the inference service.
4136
*/
4237

4338
#include <gflags/gflags.h>
@@ -50,6 +45,7 @@ namespace paddle {
5045
// flag if not available.
5146
DECLARE_bool(inference_analysis_enable_tensorrt_subgraph_engine);
5247
DECLARE_string(inference_analysis_graphviz_log_root);
48+
DECLARE_string(inference_analysis_output_storage_path);
5349

5450
namespace inference {
5551
namespace analysis {
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
/*
16+
* This file implements analysizer -- an executation help to analyze and
17+
* optimize trained model.
18+
*/
19+
#include "paddle/fluid/inference/analysis/analyzer.h"
20+
#include <gflags/gflags.h>
21+
#include <glog/logging.h>
22+
23+
int main(int argc, char** argv) {
24+
google::ParseCommandLineFlags(&argc, &argv, true);
25+
using paddle::inference::analysis::Analyzer;
26+
using paddle::inference::analysis::Argument;
27+
28+
Argument argument;
29+
Analyzer analyzer;
30+
analyzer.Run(&argument);
31+
32+
return 0;
33+
}

paddle/fluid/inference/analysis/analyzer_tester.cc

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,18 @@ namespace paddle {
2020
namespace inference {
2121
namespace analysis {
2222

23-
TEST_F(DFG_Tester, analysis_without_tensorrt) {
23+
TEST(Analyzer, analysis_without_tensorrt) {
2424
FLAGS_inference_analysis_enable_tensorrt_subgraph_engine = false;
25+
Argument argument;
26+
argument.fluid_model_dir.reset(new std::string(FLAGS_inference_model_dir));
2527
Analyzer analyser;
2628
analyser.Run(&argument);
2729
}
2830

29-
TEST_F(DFG_Tester, analysis_with_tensorrt) {
31+
TEST(Analyzer, analysis_with_tensorrt) {
3032
FLAGS_inference_analysis_enable_tensorrt_subgraph_engine = true;
33+
Argument argument;
34+
argument.fluid_model_dir.reset(new std::string(FLAGS_inference_model_dir));
3135
Analyzer analyser;
3236
analyser.Run(&argument);
3337
}

paddle/fluid/inference/analysis/argument.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,16 @@ namespace analysis {
3636
* All the fields should be registered here for clearness.
3737
*/
3838
struct Argument {
39+
Argument() = default;
40+
explicit Argument(const std::string& fluid_model_dir)
41+
: fluid_model_dir(new std::string(fluid_model_dir)) {}
42+
// The directory of the trained model.
43+
std::unique_ptr<std::string> fluid_model_dir;
44+
// The path of `__model__` and `param`, this is used when the file name of
45+
// model and param is changed.
46+
std::unique_ptr<std::string> fluid_model_program_path;
47+
std::unique_ptr<std::string> fluid_model_param_path;
48+
3949
// The graph that process by the Passes or PassManagers.
4050
std::unique_ptr<DataFlowGraph> main_dfg;
4151

@@ -44,6 +54,9 @@ struct Argument {
4454

4555
// The processed program desc.
4656
std::unique_ptr<framework::proto::ProgramDesc> transformed_program_desc;
57+
58+
// The output storage path of ModelStorePass.
59+
std::unique_ptr<std::string> model_output_store_path;
4760
};
4861

4962
#define UNLIKELY(condition) __builtin_expect(static_cast<bool>(condition), 0)

0 commit comments

Comments
 (0)