
Commit c73977a

Merge branch 'develop' into trt
2 parents: e8e8ad0 + ca2d6d3

21 files changed: +442 −96 lines

paddle/fluid/framework/scope.cc

Lines changed: 17 additions & 13 deletions
@@ -34,13 +34,7 @@ DEFINE_bool(
 namespace paddle {
 namespace framework {

-Scope::~Scope() {
-  DropKids();
-  for (auto& kv : vars_) {
-    VLOG(3) << "Destroy variable " << kv.first;
-    delete kv.second;
-  }
-}
+Scope::~Scope() { DropKids(); }

 Scope& Scope::NewScope() const {
   std::unique_lock<std::mutex> lock(mutex_);
@@ -49,10 +43,13 @@ Scope& Scope::NewScope() const {
 }

 Variable* Scope::Var(const std::string& name) {
+  // acquire the lock when new var under this scope
+  std::unique_lock<std::mutex> lock(mutex_);
   auto* v = FindVarLocally(name);
   if (v != nullptr) return v;
+
   v = new Variable();
-  vars_[name] = v;
+  vars_[name].reset(v);
   VLOG(3) << "Create variable " << name;
   v->name_ = &(vars_.find(name)->first);
   return v;
@@ -67,22 +64,29 @@ Variable* Scope::Var(std::string* name) {
 }

 Variable* Scope::FindVar(const std::string& name) const {
+  // acquire the lock when find var
+  std::unique_lock<std::mutex> lock(mutex_);
+  return FindVarInternal(name);
+}
+
+Variable* Scope::FindVarInternal(const std::string& name) const {
   auto var = FindVarLocally(name);
   if (var != nullptr) {
     return var;
   }
-  return (parent_ == nullptr) ? nullptr : parent_->FindVar(name);
+  return (parent_ == nullptr) ? nullptr : parent_->FindVarInternal(name);
 }

 const Scope* Scope::FindScope(const Variable* var) const {
   for (auto& kv : vars_) {
-    if (kv.second == var) {
+    if (kv.second.get() == var) {
       return this;
     }
   }
   return (parent_ == nullptr) ? nullptr : parent_->FindScope(var);
 }
 void Scope::DropKids() {
+  std::unique_lock<std::mutex> lock(mutex_);
   for (Scope* s : kids_) delete s;
   kids_.clear();
 }
@@ -110,10 +114,10 @@ void Scope::DeleteScope(Scope* scope) const {
 }

 void Scope::EraseVars(const std::vector<std::string>& var_names) {
+  std::unique_lock<std::mutex> lock(mutex_);
   std::set<std::string> var_set(var_names.begin(), var_names.end());
   for (auto it = vars_.begin(); it != vars_.end();) {
     if (var_set.find(it->first) != var_set.end()) {
-      delete it->second;
       it = vars_.erase(it);
     } else {
       ++it;
@@ -129,7 +133,7 @@ void Scope::Rename(const std::string& origin_name,
   auto new_it = vars_.find(new_name);
   PADDLE_ENFORCE(new_it == vars_.end(),
                  "The variable with name %s is already in the scope", new_name);
-  vars_[new_name] = origin_it->second;
+  vars_[new_name].reset(origin_it->second.release());
   vars_.erase(origin_it);
 }

@@ -141,7 +145,7 @@ std::string Scope::Rename(const std::string& origin_name) const {

 Variable* Scope::FindVarLocally(const std::string& name) const {
   auto it = vars_.find(name);
-  if (it != vars_.end()) return it->second;
+  if (it != vars_.end()) return it->second.get();
   return nullptr;
 }
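The locking added above makes Var, FindVar, EraseVars and DropKids safe to call concurrently on the same Scope. A minimal sketch of the concurrency this targets; the Scope API matches the diff above, while the thread bodies and variable names are illustrative only and not part of the commit:

#include <string>
#include <thread>

#include "paddle/fluid/framework/scope.h"

// Illustrative only: two threads touch the same Scope concurrently.
// Each Var()/FindVar() call now takes the scope's internal mutex, so the
// unordered_map `vars_` is never mutated and read at the same time.
void ConcurrentScopeDemo() {
  paddle::framework::Scope scope;

  std::thread producer([&scope] {
    for (int i = 0; i < 100; ++i) {
      scope.Var("x" + std::to_string(i));  // creates the variable if missing
    }
  });
  std::thread consumer([&scope] {
    for (int i = 0; i < 100; ++i) {
      // May return nullptr if the producer has not created it yet.
      (void)scope.FindVar("x" + std::to_string(i));
    }
  });

  producer.join();
  consumer.join();
}  // ~Scope() drops kids; vars_ entries free themselves via unique_ptr.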

paddle/fluid/framework/scope.h

Lines changed: 14 additions & 3 deletions
@@ -47,15 +47,18 @@ class Scope {
   Scope& NewScope() const;

   /// Create a variable with given name if it doesn't exist.
+  /// Caller doesn't own the returned Variable.
   Variable* Var(const std::string& name);

   /// Create a variable with a scope-unique name.
+  /// Caller doesn't own the returned Variable.
   Variable* Var(std::string* name = nullptr);

   void EraseVars(const std::vector<std::string>& var_names);

   /// Find a variable in the scope or any of its ancestors. Returns
   /// nullptr if cannot find.
+  /// Caller doesn't own the returned Variable.
   Variable* FindVar(const std::string& name) const;

   const Scope* parent() const { return parent_; }
@@ -78,13 +81,21 @@ class Scope {
   // Rename variable to a new name and return the new name
   std::string Rename(const std::string& origin_name) const;

-  Variable* FindVarLocally(const std::string& name) const;
-
  private:
   // Call Scope::NewScope for a sub-scope.
   explicit Scope(Scope const* parent) : parent_(parent) {}

-  mutable std::unordered_map<std::string, Variable*> vars_;
+  // Called by FindVar recursively.
+  // Caller doesn't own the returned Variable.
+  Variable* FindVarInternal(const std::string& name) const;
+
+  // Called by FindVarInternal and Var.
+  // Caller doesn't own the returned Variable.
+  Variable* FindVarLocally(const std::string& name) const;
+
+  mutable std::unordered_map<std::string, std::unique_ptr<Variable>> vars_;
+
+  // Scope in `kids_` are owned by this class.
   mutable std::list<Scope*> kids_;
   Scope const* parent_{nullptr};

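Because vars_ now stores std::unique_ptr<Variable>, every pointer handed out by Var/FindVar is non-owning: the scope frees the variable when it is erased or the scope is destroyed. A hedged sketch of the intended calling convention; the tensor type and variable name are illustrative, not part of the commit:

#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"

void OwnershipDemo(paddle::framework::Scope* scope) {
  // The returned pointer is owned by the scope's unique_ptr map.
  paddle::framework::Variable* var = scope->Var("w0");
  auto* tensor = var->GetMutable<paddle::framework::LoDTensor>();
  (void)tensor;

  // Do NOT `delete var;` -- EraseVars()/~Scope() release it.
  scope->EraseVars({"w0"});
}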

paddle/fluid/inference/analysis/helper.h

Lines changed: 9 additions & 0 deletions
@@ -18,6 +18,8 @@ limitations under the License. */
 #include <unordered_map>
 #include <vector>

+#include "paddle/fluid/framework/scope.h"
+#include "paddle/fluid/framework/variable.h"
 #include "paddle/fluid/platform/enforce.h"

 namespace paddle {
@@ -107,6 +109,13 @@ class OrderedRegistry {
   std::vector<std::unique_ptr<T>> data_;
 };

+template <typename T>
+T &GetFromScope(const framework::Scope &scope, const std::string &name) {
+  framework::Variable *var = scope.FindVar(name);
+  PADDLE_ENFORCE(var != nullptr);
+  return *var->GetMutable<T>();
+}
+
 }  // namespace analysis
 }  // namespace inference
 }  // namespace paddle
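GetFromScope wraps the FindVar-then-GetMutable pattern and fails loudly (via PADDLE_ENFORCE) when the variable is missing. A hedged usage sketch; the tensor type and the variable name "fc_out" are illustrative, not part of the commit:

#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/inference/analysis/helper.h"

void ReadOutputDemo(const paddle::framework::Scope& scope) {
  namespace analysis = paddle::inference::analysis;
  // Aborts if "fc_out" was never created in `scope`.
  auto& fc_out =
      analysis::GetFromScope<paddle::framework::LoDTensor>(scope, "fc_out");
  (void)fc_out;
}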

paddle/fluid/inference/tensorrt/convert/CMakeLists.txt

Lines changed: 7 additions & 1 deletion
@@ -1,5 +1,11 @@
 # Add TRT tests
-nv_test(test_op_converter SRCS test_op_converter.cc mul_op.cc conv2d_op.cc DEPS ${FLUID_CORE_MODULES} tensorrt_engine)
+nv_library(tensorrt_converter
+  SRCS mul_op.cc conv2d_op.cc fc_op.cc
+  DEPS tensorrt_engine mul_op)
+
+nv_test(test_op_converter SRCS test_op_converter.cc DEPS
+  ${FLUID_CORE_MODULES} tensorrt_engine tensorrt_converter)
+
 nv_test(test_io_converter SRCS test_io_converter.cc io_converter.cc DEPS dynload_cuda dynamic_loader lod_tensor)
 nv_test(test_trt_mul_op SRCS test_mul_op.cc mul_op.cc
         DEPS ${FLUID_CORE_MODULES} tensorrt_engine mul_op SERIAL)

paddle/fluid/inference/tensorrt/convert/activation_op.cc

Lines changed: 3 additions & 2 deletions
@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */

+#include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/inference/tensorrt/convert/op_converter.h"

 namespace paddle {
@@ -37,8 +38,8 @@ class ReluOpConverter : public OpConverter {
   }
 };

-REGISTER_TRT_OP_CONVERTER(relu, ReluOpConverter);
-
 }  // namespace tensorrt
 }  // namespace inference
 }  // namespace paddle
+
+REGISTER_TRT_OP_CONVERTER(relu, ReluOpConverter);

paddle/fluid/inference/tensorrt/convert/conv2d_op.cc

Lines changed: 3 additions & 3 deletions
@@ -22,14 +22,14 @@ class Conv2dOpConverter : public OpConverter {
  public:
   Conv2dOpConverter() {}
   void operator()(const framework::proto::OpDesc& op,
-                  const framework::Scope& scope) override {
+                  const framework::Scope& scope, bool test_mode) override {
     LOG(INFO)
         << "convert a fluid conv2d op to tensorrt conv layer without bias";
   }
 };

-REGISTER_TRT_OP_CONVERTER(conv2d, Conv2dOpConverter);
-
 }  // namespace tensorrt
 }  // namespace inference
 }  // namespace paddle
+
+REGISTER_TRT_OP_CONVERTER(conv2d, Conv2dOpConverter);

paddle/fluid/inference/tensorrt/convert/fc_op.cc

Lines changed: 6 additions & 4 deletions
@@ -56,7 +56,7 @@ void ReorderCKtoKC(TensorRTEngine::Weight& iweights,
 class FcOpConverter : public OpConverter {
  public:
   void operator()(const framework::proto::OpDesc& op,
-                  const framework::Scope& scope) override {
+                  const framework::Scope& scope, bool test_mode) override {
     VLOG(4) << "convert a fluid fc op to tensorrt fc layer without bias";

     framework::OpDesc op_desc(op, nullptr);
@@ -106,14 +106,16 @@ class FcOpConverter : public OpConverter {
                                  n_output, weight.get(), bias.get());

     auto output_name = op_desc.Output("Out").front();
-    engine_->DeclareOutput(layer, 0, output_name);
+    engine_->SetITensor(output_name, layer->getOutput(0));
+    if (test_mode) {
+      engine_->DeclareOutput(output_name);
+    }
   }
 };

-REGISTER_TRT_OP_CONVERTER(fc, FcOpConverter);
-
 }  // namespace tensorrt
 }  // namespace inference
 }  // namespace paddle

+REGISTER_TRT_OP_CONVERTER(fc, FcOpConverter);
 USE_OP(mul);

paddle/fluid/inference/tensorrt/convert/mul_op.cc

Lines changed: 10 additions & 5 deletions
@@ -23,9 +23,8 @@ namespace tensorrt {
  */
 class MulOpConverter : public OpConverter {
  public:
-  MulOpConverter() {}
   void operator()(const framework::proto::OpDesc& op,
-                  const framework::Scope& scope) override {
+                  const framework::Scope& scope, bool test_mode) override {
     VLOG(4) << "convert a fluid mul op to tensorrt mul layer without bias";

     framework::OpDesc op_desc(op, nullptr);
@@ -37,12 +36,18 @@ class MulOpConverter : public OpConverter {
         engine_, MatrixMultiply, *const_cast<nvinfer1::ITensor*>(input1), false,
         *const_cast<nvinfer1::ITensor*>(input2), false);

-    engine_->DeclareOutput(layer, 0, op_desc.Output("Out")[0]);
+    auto output_name = op_desc.Output("Out")[0];
+    engine_->SetITensor(output_name, layer->getOutput(0));
+    if (test_mode) {  // the test framework can not determine which is the
+                      // output, so place the declaration inside.
+      engine_->DeclareOutput(output_name);
+    }
   }
 };

-REGISTER_TRT_OP_CONVERTER(mul, MulOpConverter);
-
 }  // namespace tensorrt
 }  // namespace inference
 }  // namespace paddle
+
+USE_OP(mul);
+REGISTER_TRT_OP_CONVERTER(mul, MulOpConverter);
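The fc and mul converters now share one output-handling pattern: always register the TRT tensor under its fluid variable name with SetITensor, and only promote it to a network output when test_mode is set, because outside of a unit test the surrounding subgraph decides which tensors are outputs. A hedged skeleton of that tail for a future converter; the helper function itself is illustrative, only the engine calls it makes are taken from the diff above:

#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/inference/tensorrt/engine.h"

// `layer` stands for whatever nvinfer1::ILayer* the converter built above.
void HandleConverterOutput(paddle::inference::tensorrt::TensorRTEngine* engine,
                           const paddle::framework::OpDesc& op_desc,
                           nvinfer1::ILayer* layer, bool test_mode) {
  auto output_name = op_desc.Output("Out")[0];
  // Make the TRT tensor reachable by its fluid variable name.
  engine->SetITensor(output_name, layer->getOutput(0));
  if (test_mode) {
    // Only the unit-test path marks it as a network output here.
    engine->DeclareOutput(output_name);
  }
}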

paddle/fluid/inference/tensorrt/convert/op_converter.h

Lines changed: 29 additions & 11 deletions
@@ -17,6 +17,7 @@ limitations under the License. */
 #include <string>
 #include <unordered_map>
 #include "paddle/fluid/framework/block_desc.h"
+#include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/scope.h"
 #include "paddle/fluid/inference/tensorrt/engine.h"
 #include "paddle/fluid/inference/utils/singleton.h"
@@ -34,12 +35,15 @@ class OpConverter {

   // Converter logic for an op.
   virtual void operator()(const framework::proto::OpDesc& op,
-                          const framework::Scope& scope) {}
+                          const framework::Scope& scope,
+                          bool test_mode = false) {}

-  // Convert a single fluid operaotr and add the corresponding layer to TRT.
+  // Convert a single fluid operator and add the corresponding layer to TRT.
+  // test_mode: whether the instance executes in an unit test.
   void ConvertOp(const framework::proto::OpDesc& op,
                  const std::unordered_set<std::string>& parameters,
-                 const framework::Scope& scope, TensorRTEngine* engine) {
+                 const framework::Scope& scope, TensorRTEngine* engine,
+                 bool test_mode = false) {
     framework::OpDesc op_desc(op, nullptr);

     OpConverter* it{nullptr};
@@ -57,7 +61,7 @@ class OpConverter {
     PADDLE_ENFORCE_NOT_NULL(it, "no OpConverter for optype [%s]",
                             op_desc.Type());
     it->SetEngine(engine);
-    (*it)(op, scope);
+    (*it)(op, scope, test_mode);
   }

   // convert fluid block to tensorrt network
@@ -77,6 +81,9 @@ class OpConverter {
   // TensorRT engine
   TensorRTEngine* engine_{nullptr};

+ protected:
+  bool test_mode_;
+
  private:
   // registered op converter map, whose key is the fluid op type, and value is
   // the pointer position of corresponding OpConverter class.
@@ -85,13 +92,24 @@ class OpConverter {
   framework::Scope* scope_{nullptr};
 };

-#define REGISTER_TRT_OP_CONVERTER(op_type__, Converter__)       \
-  struct trt_##op_type__##_converter {                          \
-    trt_##op_type__##_converter() {                             \
-      Registry<OpConverter>::Register<Converter__>(#op_type__); \
-    }                                                            \
-  };                                                             \
-  trt_##op_type__##_converter trt_##op_type__##_converter__;
+#define REGISTER_TRT_OP_CONVERTER(op_type__, Converter__)                      \
+  struct trt_##op_type__##_converter : public ::paddle::framework::Registrar { \
+    trt_##op_type__##_converter() {                                            \
+      ::paddle::inference::                                                    \
+          Registry<paddle::inference::tensorrt::OpConverter>::Register<        \
+              ::paddle::inference::tensorrt::Converter__>(#op_type__);         \
+    }                                                                           \
+  };                                                                            \
+  trt_##op_type__##_converter trt_##op_type__##_converter__;                   \
+  int TouchConverterRegister_##op_type__() {                                   \
+    trt_##op_type__##_converter__.Touch();                                     \
+    return 0;                                                                   \
+  }
+
+#define USE_TRT_CONVERTER(op_type__)                                     \
+  extern int TouchConverterRegister_##op_type__();                       \
+  static int use_op_converter_trt_##op_type__ __attribute__((unused)) =  \
+      TouchConverterRegister_##op_type__();

 }  // namespace tensorrt
 }  // namespace inference
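With this change a converter registers itself at global scope in its own .cc file, and any other translation unit that needs it pulls the registration in with USE_TRT_CONVERTER so the linker cannot drop the registrar object. A hedged sketch for a hypothetical sigmoid converter; the op name and class are illustrative, only the macros and the operator() signature come from the diff above:

// some_converter.cc -- hypothetical converter, illustrative only.
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"

namespace paddle {
namespace inference {
namespace tensorrt {

class SigmoidOpConverter : public OpConverter {
 public:
  void operator()(const framework::proto::OpDesc& op,
                  const framework::Scope& scope, bool test_mode) override {
    // build the TRT activation layer here ...
  }
};

}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle

// Registration lives outside the namespaces, matching the relu/fc/mul files.
REGISTER_TRT_OP_CONVERTER(sigmoid, SigmoidOpConverter);

// Any translation unit that needs this converter linked in (e.g. a test)
// pulls in the registrar with:
//   USE_TRT_CONVERTER(sigmoid)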

paddle/fluid/inference/tensorrt/convert/test_op_converter.cc

Lines changed: 2 additions & 0 deletions
@@ -36,3 +36,5 @@ TEST(OpConverter, ConvertBlock) {
 }  // namespace tensorrt
 }  // namespace inference
 }  // namespace paddle
+
+USE_TRT_CONVERTER(conv2d)
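On the test side, a converter unit test passes test_mode = true to ConvertOp so the matched converter declares its own output tensor, since there is no surrounding subgraph to do so. A hedged sketch of that call; the engine, scope, and op desc setup are elided and the wrapper name is illustrative:

#include <string>
#include <unordered_set>

#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"

// Illustrative: convert one fluid op inside a unit test.
void ConvertSingleOpForTest(const paddle::framework::proto::OpDesc& op,
                            const paddle::framework::Scope& scope,
                            paddle::inference::tensorrt::TensorRTEngine* engine) {
  std::unordered_set<std::string> parameters;  // no persistable params here
  paddle::inference::tensorrt::OpConverter converter;
  // test_mode = true: the converter calls DeclareOutput itself.
  converter.ConvertOp(op, parameters, scope, engine, /*test_mode=*/true);
}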
