Skip to content

Commit d1c407f

Browse files
committed
Merge branch 'fix_conv1d' into 'release/1.0'
fix(aten::conv1d): Update namespace, fix typo in dest IR for conv1d See merge request adlsa/TRTorch!17
2 parents 540e135 + d53f136 commit d1c407f

File tree

5 files changed

+56
-85
lines changed

5 files changed

+56
-85
lines changed

core/lowering/passes/BUILD

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,7 @@ config_setting(
1010
cc_library(
1111
name = "passes",
1212
srcs = [
13-
"conv1d_to_convolution.cpp",
14-
"conv2d_to_convolution.cpp",
15-
"conv3d_to_convolution.cpp",
13+
"convNd_to_convolution.cpp",
1614
"exception_elimination.cpp",
1715
"fuse_addmm_branches.cpp",
1816
"linear_to_addmm.cpp",

core/lowering/passes/conv2d_to_convolution.cpp

Lines changed: 0 additions & 33 deletions
This file was deleted.

core/lowering/passes/conv3d_to_convolution.cpp

Lines changed: 0 additions & 33 deletions
This file was deleted.

core/lowering/passes/conv1d_to_convolution.cpp renamed to core/lowering/passes/convNd_to_convolution.cpp

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
#include "core/util/prelude.h"
44

5-
namespace trtorch {
5+
namespace torch_tensorrt {
66
namespace core {
77
namespace lowering {
88
namespace passes {
@@ -12,6 +12,7 @@ void Conv1DToConvolution(std::shared_ptr<torch::jit::Graph>& graph) {
1212
graph(%x, %w, %b, %s, %p, %d, %g):
1313
%4 : Tensor = aten::conv1d(%x, %w, %b, %s, %p, %d, %g)
1414
return (%4))IR";
15+
1516
std::string convolution_pattern = R"IR(
1617
graph(%x, %w, %b, %s, %p, %d, %g):
1718
%1 : bool = prim::Constant[value=0]()
@@ -43,7 +44,45 @@ void ConvTransposed1DToConvolution(std::shared_ptr<torch::jit::Graph>& graph) {
4344
LOG_GRAPH("Post map conv_transpose1d -> _convolution: " << *graph);
4445
}
4546

47+
void Conv2DToConvolution(std::shared_ptr<torch::jit::Graph>& graph) {
48+
std::string conv2d_pattern = R"IR(
49+
graph(%x, %w, %b, %s, %p, %d, %g):
50+
%4 : Tensor = aten::conv2d(%x, %w, %b, %s, %p, %d, %g)
51+
return (%4))IR";
52+
std::string convolution_pattern = R"IR(
53+
graph(%x, %w, %b, %s, %p, %d, %g):
54+
%1 : bool = prim::Constant[value=0]()
55+
%2 : int[] = prim::Constant[value=[0, 0]]()
56+
%4 : Tensor = aten::_convolution(%x, %w, %b, %s, %p, %d, %1, %2, %g, %1, %1, %1, %1)
57+
return (%4))IR";
58+
59+
// replace matmul + add pattern to linear
60+
torch::jit::SubgraphRewriter map_conv2d_to_convolution;
61+
map_conv2d_to_convolution.RegisterRewritePattern(conv2d_pattern, convolution_pattern);
62+
map_conv2d_to_convolution.runOnGraph(graph);
63+
LOG_GRAPH("Post map conv2d -> _convolution: " << *graph);
64+
}
65+
66+
void Conv3DToConvolution(std::shared_ptr<torch::jit::Graph>& graph) {
67+
std::string conv3d_pattern = R"IR(
68+
graph(%x, %w, %b, %s, %p, %d, %g):
69+
%4 : Tensor = aten::conv3d(%x, %w, %b, %s, %p, %d, %g)
70+
return (%4))IR";
71+
std::string convolution_pattern = R"IR(
72+
graph(%x, %w, %b, %s, %p, %d, %g):
73+
%1 : bool = prim::Constant[value=0]()
74+
%2 : int[] = prim::Constant[value=[0, 0, 0]]()
75+
%4 : Tensor = aten::_convolution(%x, %w, %b, %s, %p, %d, %1, %2, %g, %1, %1, %1, %1)
76+
return (%4))IR";
77+
78+
// replace matmul + add pattern to linear
79+
torch::jit::SubgraphRewriter map_conv3d_to_convolution;
80+
map_conv3d_to_convolution.RegisterRewritePattern(conv3d_pattern, convolution_pattern);
81+
map_conv3d_to_convolution.runOnGraph(graph);
82+
LOG_GRAPH("Post map conv3d -> _convolution: " << *graph);
83+
}
84+
4685
} // namespace passes
4786
} // namespace lowering
4887
} // namespace core
49-
} // namespace trtorch
88+
} // namespace torch_tensorrt

tests/core/lowering/test_conv1d_pass.cpp

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,10 @@ TEST(LoweringPasses, Conv1dCorrectly) {
3535
%12 : Tensor = aten::_convolution(%0, %1, %2, %stride, %padding, %dilation, %3, %output_padding, %6, %3, %3, %3, %3)
3636
return (%12))IR";
3737

38-
trtorch::core::util::logging::get_logger().set_reportable_log_level(trtorch::core::util::logging::LogLevel::kGRAPH);
38+
torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
3939
auto sg = std::make_shared<torch::jit::Graph>();
4040
torch::jit::parseIR(source_graph, &*sg);
41-
trtorch::core::lowering::passes::Conv1DToConvolution(sg);
41+
torch_tensorrt::core::lowering::passes::Conv1DToConvolution(sg);
4242

4343
auto tg = std::make_shared<torch::jit::Graph>();
4444
torch::jit::parseIR(target_graph, &*tg);
@@ -50,13 +50,13 @@ TEST(LoweringPasses, Conv1dCorrectly) {
5050
auto trt_in = at::clone(in);
5151
auto trt_w = at::clone(w);
5252
auto trt_b = at::clone(b);
53-
auto params = trtorch::core::conversion::get_named_params(sg->inputs(), {trt_w, trt_b});
54-
auto trt_results_sg = trtorch::tests::util::RunGraphEngine(sg, params, {trt_in});
53+
auto params = torch_tensorrt::core::ir::get_static_params(sg->inputs(), {trt_w, trt_b});
54+
auto trt_results_sg = torch_tensorrt::tests::util::RunGraphEngine(sg, params, {trt_in});
5555

56-
params = trtorch::core::conversion::get_named_params(tg->inputs(), {trt_w, trt_b});
57-
auto trt_results_tg = trtorch::tests::util::RunGraphEngine(tg, params, {trt_in});
56+
params = torch_tensorrt::core::ir::get_static_params(tg->inputs(), {trt_w, trt_b});
57+
auto trt_results_tg = torch_tensorrt::tests::util::RunGraphEngine(tg, params, {trt_in});
5858

59-
ASSERT_TRUE(trtorch::tests::util::almostEqual(trt_results_sg[0], trt_results_tg[0], 2e-6));
59+
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(trt_results_sg[0], trt_results_tg[0], 2e-6));
6060
}
6161

6262
TEST(LoweringPasses, ConvTransposed1dCorrectly) {
@@ -92,10 +92,10 @@ TEST(LoweringPasses, ConvTransposed1dCorrectly) {
9292
%12 : Tensor = aten::_convolution(%0, %1, %2, %stride, %padding, %dilation, %8, %output_padding, %5, %7, %7, %7, %7)
9393
return (%12))IR";
9494

95-
trtorch::core::util::logging::get_logger().set_reportable_log_level(trtorch::core::util::logging::LogLevel::kGRAPH);
95+
torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
9696
auto sg = std::make_shared<torch::jit::Graph>();
9797
torch::jit::parseIR(source_graph, &*sg);
98-
trtorch::core::lowering::passes::ConvTransposed1DToConvolution(sg);
98+
torch_tensorrt::core::lowering::passes::ConvTransposed1DToConvolution(sg);
9999

100100
auto tg = std::make_shared<torch::jit::Graph>();
101101
torch::jit::parseIR(target_graph, &*tg);
@@ -107,11 +107,11 @@ TEST(LoweringPasses, ConvTransposed1dCorrectly) {
107107
auto trt_in = at::clone(in);
108108
auto trt_w = at::clone(w);
109109
auto trt_b = at::clone(b);
110-
auto params = trtorch::core::conversion::get_named_params(sg->inputs(), {trt_w, trt_b});
111-
auto trt_results_sg = trtorch::tests::util::RunGraphEngine(sg, params, {trt_in});
110+
auto params = torch_tensorrt::core::ir::get_static_params(sg->inputs(), {trt_w, trt_b});
111+
auto trt_results_sg = torch_tensorrt::tests::util::RunGraphEngine(sg, params, {trt_in});
112112

113-
params = trtorch::core::conversion::get_named_params(tg->inputs(), {trt_w, trt_b});
114-
auto trt_results_tg = trtorch::tests::util::RunGraphEngine(tg, params, {trt_in});
113+
params = torch_tensorrt::core::ir::get_static_params(tg->inputs(), {trt_w, trt_b});
114+
auto trt_results_tg = torch_tensorrt::tests::util::RunGraphEngine(tg, params, {trt_in});
115115

116-
ASSERT_TRUE(trtorch::tests::util::almostEqual(trt_results_sg[0], trt_results_tg[0], 2e-6));
116+
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(trt_results_sg[0], trt_results_tg[0], 2e-6));
117117
}

0 commit comments

Comments
 (0)