Skip to content

Commit 53b401d

Browse files
committed
refine io_convert and op_convert
1 parent 2a2c83b commit 53b401d

File tree

9 files changed

+37
-39
lines changed

9 files changed

+37
-39
lines changed
Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
nv_test(test_tensorrt SRCS test_tensorrt.cc DEPS dynload_cuda device_context dynamic_loader)
22
nv_test(test_tensorrt_engine SRCS test_engine.cc engine.cc DEPS dynload_cuda)
3-
nv_test(test_io_converter SRCS test_io_converter.cc io_converter.cc DEPS dynload_cuda dynamic_loader lod_tensor)
43
set(ENGINE_FILE ${CMAKE_CURRENT_SOURCE_DIR}/engine.cc)
54
add_subdirectory(convert)
Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1-
nv_test(test_tensorrt_op_converter SRCS test_op_converter.cc mul_op.cc conv2d_op.cc DEPS ${FLUID_CORE_MODULES})
2-
nv_test(test_tensorrt_activation_op SRCS test_activation_op.cc ${ENGINE_FILE} activation_op.cc
1+
nv_test(test_op_converter SRCS test_op_converter.cc mul_op.cc conv2d_op.cc DEPS ${FLUID_CORE_MODULES})
2+
nv_test(test_trt_activation_op SRCS test_activation_op.cc ${ENGINE_FILE} activation_op.cc
33
DEPS ${FLUID_CORE_MODULES} activation_op)
4+
nv_test(test_io_converter SRCS test_io_converter.cc io_converter.cc DEPS dynload_cuda dynamic_loader lod_tensor)

paddle/fluid/inference/tensorrt/io_converter.cc renamed to paddle/fluid/inference/tensorrt/convert/io_converter.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

15-
#include "paddle/fluid/inference/tensorrt/io_converter.h"
15+
#include "paddle/fluid/inference/tensorrt/convert/io_converter.h"
1616
#include <cuda.h>
1717
#include "paddle/fluid/platform/enforce.h"
1818

@@ -50,7 +50,7 @@ class DefaultInputConverter : public EngineInputConverter {
5050
}
5151
};
5252

53-
REGISTER_TENSORRT_INPUT_CONVERTER(mul, DefaultInputConverter);
53+
REGISTER_TENSORRT_INPUT_CONVERTER(default, DefaultInputConverter);
5454

5555
} // namespace tensorrt
5656
} // namespace inference

paddle/fluid/inference/tensorrt/io_converter.h renamed to paddle/fluid/inference/tensorrt/convert/io_converter.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,8 @@ class EngineInputConverter {
4040
static void Run(const std::string& in_op_type, const LoDTensor& in, void* out,
4141
size_t max_size, cudaStream_t* stream) {
4242
PADDLE_ENFORCE(stream != nullptr);
43-
auto* converter = Registry<EngineInputConverter>::Lookup(in_op_type);
43+
auto* converter = Registry<EngineInputConverter>::Lookup(
44+
in_op_type, "default" /* default_type */);
4445
PADDLE_ENFORCE_NOT_NULL(converter);
4546
converter->SetStream(stream);
4647
(*converter)(in, out, max_size);

paddle/fluid/inference/tensorrt/convert/op_converter.h

Lines changed: 14 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ limitations under the License. */
1919
#include "paddle/fluid/framework/block_desc.h"
2020
#include "paddle/fluid/framework/scope.h"
2121
#include "paddle/fluid/inference/tensorrt/engine.h"
22+
#include "paddle/fluid/inference/utils/singleton.h"
2223

2324
namespace paddle {
2425
namespace inference {
@@ -32,34 +33,23 @@ class OpConverter {
3233
OpConverter() {}
3334
virtual void operator()(const framework::OpDesc& op) {}
3435

35-
void Execute(const framework::OpDesc& op, TensorRTEngine* engine) {
36+
void Run(const framework::OpDesc& op, TensorRTEngine* engine) {
3637
std::string type = op.Type();
37-
auto it = converters_.find(type);
38-
PADDLE_ENFORCE(it != converters_.end(), "no OpConverter for optype [%s]",
39-
type);
40-
it->second->SetEngine(engine);
41-
(*it->second)(op);
42-
}
43-
44-
static OpConverter& Global() {
45-
static auto* x = new OpConverter;
46-
return *x;
47-
}
48-
49-
template <typename T>
50-
void Register(const std::string& key) {
51-
converters_[key] = new T;
38+
auto* it = Registry<OpConverter>::Lookup(type);
39+
PADDLE_ENFORCE_NOT_NULL(it, "no OpConverter for optype [%s]", type);
40+
it->SetEngine(engine);
41+
(*it)(op);
5242
}
5343

5444
// convert fluid op to tensorrt layer
5545
void ConvertOp(const framework::OpDesc& op, TensorRTEngine* engine) {
56-
OpConverter::Global().Execute(op, engine);
46+
OpConverter::Run(op, engine);
5747
}
5848

5949
// convert fluid block to tensorrt network
6050
void ConvertBlock(const framework::BlockDesc& block, TensorRTEngine* engine) {
6151
for (auto op : block.AllOps()) {
62-
OpConverter::Global().Execute(*op, engine);
52+
OpConverter::Run(*op, engine);
6353
}
6454
}
6555

@@ -78,12 +68,12 @@ class OpConverter {
7868
framework::Scope* scope_{nullptr};
7969
};
8070

81-
#define REGISTER_TRT_OP_CONVERTER(op_type__, Converter__) \
82-
struct trt_##op_type__##_converter { \
83-
trt_##op_type__##_converter() { \
84-
OpConverter::Global().Register<Converter__>(#op_type__); \
85-
} \
86-
}; \
71+
#define REGISTER_TRT_OP_CONVERTER(op_type__, Converter__) \
72+
struct trt_##op_type__##_converter { \
73+
trt_##op_type__##_converter() { \
74+
Registry<OpConverter>::Register<Converter__>(#op_type__); \
75+
} \
76+
}; \
8777
trt_##op_type__##_converter trt_##op_type__##_converter__;
8878

8979
} // namespace tensorrt

paddle/fluid/inference/tensorrt/convert/test_activation_op.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ namespace paddle {
2626
namespace inference {
2727
namespace tensorrt {
2828

29-
void compare(float input, float expect) {
29+
void Compare(float input, float expect) {
3030
framework::Scope scope;
3131
platform::CUDAPlace place;
3232
platform::CUDADeviceContext ctx(place);
@@ -85,8 +85,8 @@ void compare(float input, float expect) {
8585
}
8686

8787
TEST(OpConverter, ConvertRelu) {
88-
compare(1, 1); // relu(1) = 1
89-
compare(-5, 0); // relu(-5) = 0
88+
Compare(1, 1); // relu(1) = 1
89+
Compare(-5, 0); // relu(-5) = 0
9090
}
9191

9292
} // namespace tensorrt

paddle/fluid/inference/tensorrt/test_io_converter.cc renamed to paddle/fluid/inference/tensorrt/convert/test_io_converter.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

1515
#include "paddle/fluid/framework/lod_tensor.h"
16-
#include "paddle/fluid/inference/tensorrt/io_converter.h"
16+
#include "paddle/fluid/inference/tensorrt/convert/io_converter.h"
1717

1818
#include <gtest/gtest.h>
1919

@@ -34,7 +34,7 @@ TEST_F(EngineInputConverterTester, DefaultCPU) {
3434
ASSERT_EQ(cudaMalloc(&buffer, tensor.memory_size()), 0);
3535

3636
cudaStream_t stream;
37-
EngineInputConverter::Run("mul", tensor, buffer, tensor.memory_size(),
37+
EngineInputConverter::Run("test", tensor, buffer, tensor.memory_size(),
3838
&stream);
3939
}
4040

@@ -44,7 +44,7 @@ TEST_F(EngineInputConverterTester, DefaultGPU) {
4444
ASSERT_EQ(cudaMalloc(&buffer, tensor.memory_size()), 0);
4545

4646
cudaStream_t stream;
47-
EngineInputConverter::Run("mul", tensor, buffer, tensor.memory_size(),
47+
EngineInputConverter::Run("test", tensor, buffer, tensor.memory_size(),
4848
&stream);
4949
}
5050

paddle/fluid/inference/tensorrt/convert/test_op_converter.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ namespace paddle {
2020
namespace inference {
2121
namespace tensorrt {
2222

23-
TEST(BlockConverter, ConvertBlock) {
23+
TEST(OpConverter, ConvertBlock) {
2424
framework::ProgramDesc prog;
2525
auto* block = prog.MutableBlock(0);
2626
auto* mul_op = block->AppendOp();

paddle/fluid/inference/utils/singleton.h

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ limitations under the License. */
1414

1515
#pragma once
1616

17+
#include <string>
1718
#include <unordered_map>
1819
#include "paddle/fluid/platform/enforce.h"
1920

@@ -49,9 +50,15 @@ struct Registry {
4950
items_[name] = new ItemChild;
5051
}
5152

52-
static ItemParent* Lookup(const std::string& name) {
53+
static ItemParent* Lookup(const std::string& name,
54+
const std::string& default_name = "") {
5355
auto it = items_.find(name);
54-
if (it == items_.end()) return nullptr;
56+
if (it == items_.end()) {
57+
if (default_name == "")
58+
return nullptr;
59+
else
60+
return items_.find(default_name)->second;
61+
}
5562
return it->second;
5663
}
5764

0 commit comments

Comments
 (0)