
Commit b6d36aa (parent 01417cd)

lower aten::conv1d and aten::conv_transpose1d to aten::_convolution

Signed-off-by: Ruoqian Guo <[email protected]>

File tree: 8 files changed (+201 -150 lines)

core/conversion/converters/impl/conv_deconv.cpp
(24 additions, 63 deletions)
@@ -10,19 +10,18 @@ namespace converters {
 namespace impl {
 namespace {

-bool add_conv_deconv(
-    ConversionCtx* ctx,
-    const torch::jit::Node* n,
-    args& args,
-    nvinfer1::Dims& stride,
-    nvinfer1::Dims& padding,
-    nvinfer1::Dims& dilation,
-    bool transposed,
-    nvinfer1::Dims& out_padding,
-    int64_t groups) {
+bool add_conv_deconv(ConversionCtx* ctx, const torch::jit::Node* n, args& args) {
   // Input to conv/deconv
   auto in = args[0].ITensor();

+  // Conv/deconv parameters
+  auto stride = util::toDims(args[3].unwrapToIntList());
+  auto padding = util::toDims(args[4].unwrapToIntList());
+  auto dilation = util::toDims(args[5].unwrapToIntList());
+  bool transposed = args[6].unwrapToBool();
+  auto out_padding = util::toDims(args[7].unwrapToIntList());
+  int64_t groups = args[8].unwrapToInt();
+
   // Reshape the parameters to 2D if needed
   if (stride.nbDims == 1) {
     stride = util::unsqueezeDims(stride, 1, 1);
@@ -175,69 +174,31 @@ bool add_conv_deconv(
   return true;
 }

-auto conv_registrations TRTORCH_UNUSED =
-    RegisterNodeConversionPatterns()
-        .pattern({
-            R"SIG(aten::_convolution(Tensor input, Tensor weight,
+auto conv_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
+    .pattern({
+        R"SIG(aten::_convolution(Tensor input, Tensor weight,
                         Tensor? bias, int[] stride, int[] padding,
                         int[] dilation, bool transposed,
                         int[] output_padding, int groups, bool benchmark,
                         bool deterministic, bool cudnn_enabled, bool allow_tf32) -> (Tensor))SIG",
-            [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
-              // Conv/deconv parameters
-              auto stride = util::toDims(args[3].unwrapToIntList());
-              auto padding = util::toDims(args[4].unwrapToIntList());
-              auto dilation = util::toDims(args[5].unwrapToIntList());
-              bool transposed = args[6].unwrapToBool();
-              auto out_padding = util::toDims(args[7].unwrapToIntList());
-              int64_t groups = args[8].unwrapToInt();
-              return add_conv_deconv(ctx, n, args, stride, padding, dilation, transposed, out_padding, groups);
-            }})
-        .pattern({
-            R"SIG(aten::_convolution.deprecated(Tensor input, Tensor weight,
+        [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+          return add_conv_deconv(ctx, n, args);
+        }})
+    .pattern({
+        R"SIG(aten::_convolution.deprecated(Tensor input, Tensor weight,
                         Tensor? bias, int[] stride, int[] padding,
                         int[] dilation, bool transposed,
                         int[] output_padding, int groups, bool benchmark,
                         bool deterministic, bool cudnn_enabled) -> (Tensor))SIG",
-            [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
-              // This pattern is only matched for traced JIT models which do not
-              // have allow_tf32 bool in the function signature. The TRT conversion
-              // code is exactly the same as the above call.
-              auto stride = util::toDims(args[3].unwrapToIntList());
-              auto padding = util::toDims(args[4].unwrapToIntList());
-              auto dilation = util::toDims(args[5].unwrapToIntList());
-              bool transposed = args[6].unwrapToBool();
-              auto out_padding = util::toDims(args[7].unwrapToIntList());
-              int64_t groups = args[8].unwrapToInt();
-              return add_conv_deconv(ctx, n, args, stride, padding, dilation, transposed, out_padding, groups);
-            }})
-        .pattern(
-            {R"SIG(aten::conv1d(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups) -> Tensor)SIG",
-             [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
-               // Conv/deconv parameters
-               auto stride = util::toDims(args[3].unwrapToIntList());
-               auto padding = util::toDims(args[4].unwrapToIntList());
-               auto dilation = util::toDims(args[5].unwrapToIntList());
-               bool transposed = false;
-               nvinfer1::Dims out_padding{1, {0}};
-               int64_t groups = args[6].unwrapToInt();
-               return add_conv_deconv(ctx, n, args, stride, padding, dilation, transposed, out_padding, groups);
-             }})
-        .pattern(
-            {R"SIG(aten::conv_transpose1d(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] output_padding, int groups, int[] dilation) -> Tensor)SIG",
-             [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
-               // Conv/deconv parameters
-               auto stride = util::toDims(args[3].unwrapToIntList());
-               auto padding = util::toDims(args[4].unwrapToIntList());
-               auto out_padding = util::toDims(args[5].unwrapToIntList());
-               bool transposed = true;
-               int64_t groups = args[6].unwrapToInt();
-               auto dilation = util::toDims(args[7].unwrapToIntList());
-               return add_conv_deconv(ctx, n, args, stride, padding, dilation, transposed, out_padding, groups);
-             }});
+        [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+          // This pattern is only matched for traced JIT models which do not
+          // have allow_tf32 bool in the function signature. The TRT conversion
+          // code is exactly the same as the above call.
+          return add_conv_deconv(ctx, n, args);
+        }});
 } // namespace
 } // namespace impl
 } // namespace converters
 } // namespace conversion
 } // namespace core
-} // namespace trtorch
+} // namespace trtorch
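The consolidation works because aten::conv1d is aten::_convolution with transposed=false and output_padding=[0] (conv_transpose1d being the transposed analogue), so one converter body can serve every registered pattern. A standalone libtorch sanity check of that equivalence, a minimal sketch for illustration rather than code from this commit (shapes are arbitrary):

#include <ATen/ATen.h>
#include <iostream>

int main() {
  auto x = at::randn({1, 3, 8});  // (N, C_in, L)
  auto w = at::randn({4, 3, 3});  // (C_out, C_in, kernel)
  auto b = at::randn({4});

  auto direct = at::conv1d(x, w, b, /*stride=*/{1}, /*padding=*/{0}, /*dilation=*/{1}, /*groups=*/1);

  // The form the lowering pass rewrites to: transposed=false, output_padding=[0].
  auto lowered = at::_convolution(
      x, w, b,
      /*stride=*/{1}, /*padding=*/{0}, /*dilation=*/{1},
      /*transposed=*/false, /*output_padding=*/{0}, /*groups=*/1,
      /*benchmark=*/false, /*deterministic=*/false, /*cudnn_enabled=*/false,
      /*allow_tf32=*/false);

  std::cout << direct.allclose(lowered) << std::endl;  // expect 1
  return 0;
}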

core/lowering/lowering.cpp
(2 additions, 0 deletions)
@@ -46,6 +46,8 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, LowerInfo lower_info) {
   passes::RemoveContiguous(g);
   passes::RemoveDropout(g);
   passes::LinearToAddMM(g);
+  passes::Conv1DToConvolution(g);
+  passes::ConvTransposed1DToConvolution(g);
   passes::Conv2DToConvolution(g);
   passes::Conv3DToConvolution(g);
   passes::FuseAddMMBranches(g);
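Each of these passes is pattern-driven and leaves a graph untouched when its target op does not appear, so the new 1D passes simply slot in ahead of the existing 2D/3D ones. A minimal sketch of this conv-normalization slice of the pipeline in isolation (the NormalizeConvs helper is hypothetical, not part of the commit):

#include <memory>

#include <torch/csrc/jit/ir/ir.h>

#include "core/lowering/passes/passes.h"

// Rewrites every aten::convNd / aten::conv_transpose1d node in g to
// aten::_convolution; a no-op for graphs without those ops.
void NormalizeConvs(std::shared_ptr<torch::jit::Graph>& g) {
  namespace passes = trtorch::core::lowering::passes;
  passes::Conv1DToConvolution(g);
  passes::ConvTransposed1DToConvolution(g);
  passes::Conv2DToConvolution(g);
  passes::Conv3DToConvolution(g);
}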

core/lowering/passes/BUILD
(1 addition, 0 deletions)
@@ -10,6 +10,7 @@ config_setting(
 cc_library(
     name = "passes",
     srcs = [
+        "conv1d_to_convolution.cpp",
         "conv2d_to_convolution.cpp",
         "conv3d_to_convolution.cpp",
         "exception_elimination.cpp",

core/lowering/passes/conv1d_to_convolution.cpp (new file)
(49 additions, 0 deletions)
@@ -0,0 +1,49 @@
+#include <torch/csrc/jit/passes/subgraph_rewrite.h>
+
+#include "core/util/prelude.h"
+
+namespace trtorch {
+namespace core {
+namespace lowering {
+namespace passes {
+
+void Conv1DToConvolution(std::shared_ptr<torch::jit::Graph>& graph) {
+  std::string conv1d_pattern = R"IR(
+      graph(%x, %w, %b, %s, %p, %d, %g):
+        %4 : Tensor = aten::conv1d(%x, %w, %b, %s, %p, %d, %g)
+        return (%4))IR";
+  std::string convolution_pattern = R"IR(
+      graph(%x, %w, %b, %s, %p, %d, %g):
+        %1 : bool = prim::Constant[value=0]()
+        %2 : int[] = prim::Constant[value=[0]]()
+        %4 : Tensor = aten::_convolution(%x, %w, %b, %s, %p, %d, %1, %2, %g, %1, %1, %1, %1)
+        return (%4))IR";
+
+  torch::jit::SubgraphRewriter map_conv1d_to_convolution;
+  map_conv1d_to_convolution.RegisterRewritePattern(conv1d_pattern, convolution_pattern);
+  map_conv1d_to_convolution.runOnGraph(graph);
+  LOG_GRAPH("Post map conv1d -> _convolution: " << *graph);
+}
+
+void ConvTransposed1DToConvolution(std::shared_ptr<torch::jit::Graph>& graph) {
+  std::string conv_transpose1d_pattern = R"IR(
+      graph(%x, %w, %b, %s, %p, %o, %g, %d):
+        %4 : Tensor = aten::conv_transpose1d(%x, %w, %b, %s, %p, %o, %g, %d)
+        return (%4))IR";
+  std::string convolution_pattern = R"IR(
+      graph(%x, %w, %b, %s, %p, %o, %g, %d):
+        %1 : bool = prim::Constant[value=1]()
+        %2 : bool = prim::Constant[value=1]()
+        %4 : Tensor = aten::_convolution(%x, %w, %b, %s, %p, %d, %1, %o, %g, %2, %2, %2, %2)
+        return (%4))IR";
+
+  torch::jit::SubgraphRewriter map_conv_transpose1d_to_convolution;
+  map_conv_transpose1d_to_convolution.RegisterRewritePattern(conv_transpose1d_pattern, convolution_pattern);
+  map_conv_transpose1d_to_convolution.runOnGraph(graph);
+  LOG_GRAPH("Post map conv_transpose1d -> _convolution: " << *graph);
+}
+
+} // namespace passes
+} // namespace lowering
+} // namespace core
+} // namespace trtorch
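To exercise the pass outside the full pipeline, one can parse a small graph and run it directly. A minimal sketch, assuming only the torch::jit IR parser and the pass declared in passes.h (the driver itself is illustrative, not part of the commit):

#include <memory>
#include <string>

#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/irparser.h>

#include "core/lowering/passes/passes.h"

int main() {
  // A graph containing exactly the subgraph the rewrite pattern matches.
  const std::string src = R"IR(
    graph(%x, %w, %b, %s, %p, %d, %g):
      %out : Tensor = aten::conv1d(%x, %w, %b, %s, %p, %d, %g)
      return (%out))IR";

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(src, g.get());

  // After the pass, the dump shows aten::_convolution in place of aten::conv1d.
  trtorch::core::lowering::passes::Conv1DToConvolution(g);
  g->dump();
  return 0;
}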

core/lowering/passes/passes.h
(2 additions, 0 deletions)
@@ -12,6 +12,8 @@ void NotateModuleForFallback(
     std::string mod_name,
     std::string method_name,
     std::unordered_set<std::string> forced_fallback_modules);
+void Conv1DToConvolution(std::shared_ptr<torch::jit::Graph>& graph);
+void ConvTransposed1DToConvolution(std::shared_ptr<torch::jit::Graph>& graph);
 void Conv2DToConvolution(std::shared_ptr<torch::jit::Graph>& graph);
 void Conv3DToConvolution(std::shared_ptr<torch::jit::Graph>& graph);
 void FuseAddMMBranches(std::shared_ptr<torch::jit::Graph> graph);

tests/core/conversion/converters/test_conv_deconv.cpp
(1 addition, 87 deletions)
@@ -10,12 +10,6 @@
 // int[] output_padding, int groups, bool benchmark,
 // bool deterministic, bool cudnn_enabled) -> (Tensor)

-// aten::conv1d(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups) ->
-// Tensor
-
-// aten::conv_transpose1d(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] output_padding,
-// int groups, int[] dilation) -> Tensor
-
 void conv_test_helper(std::string graph_ir) {
   auto g = std::make_shared<torch::jit::Graph>();
   torch::jit::parseIR(graph_ir, g.get());
@@ -122,86 +116,6 @@ TEST(Converters, ATenConvolution1dConvertsCorrectly) {
   ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
 }

-TEST(Converters, ATenConv1dConvertsCorrectly) {
-  const auto graph = R"IR(
-      graph(%0 : Tensor,
-            %1 : Float(4, 3, 3, strides=[9, 3, 1]),
-            %2 : Float(3)):
-        %3 : int = prim::Constant[value=1]()
-        %4 : int = prim::Constant[value=0]()
-        %5 : int = prim::Constant[value=1]()
-        %8 : int[] = prim::ListConstruct(%3)
-        %9 : int[] = prim::ListConstruct(%4)
-        %10 : int[] = prim::ListConstruct(%5)
-        %12 : Tensor = aten::conv1d(%0, %1, %2, %8, %9, %10, %3)
-        return (%12))IR";
-
-  auto g = std::make_shared<torch::jit::Graph>();
-  torch::jit::parseIR(graph, g.get());
-
-  auto in = at::randint(1, 2, {1, 3, 3}, {at::kCUDA});
-  auto w = at::randint(1, 2, {4, 3, 3}, {at::kCUDA});
-  auto b = at::randint(1, 10, {4}, {at::kCUDA});
-
-  auto jit_in = at::clone(in);
-  auto jit_w = at::clone(w);
-  auto jit_b = at::clone(b);
-
-  auto params = trtorch::core::conversion::get_named_params(g->inputs(), {jit_w, jit_b});
-  auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in});
-
-  auto trt_in = at::clone(in);
-  auto trt_w = at::clone(w);
-  auto trt_b = at::clone(b);
-  params = trtorch::core::conversion::get_named_params(g->inputs(), {trt_w, trt_b});
-  auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in});
-
-  auto trt = trt_results[0].reshape(jit_results[0].sizes());
-
-  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
-}
-
-TEST(Converters, ATenConvTranspose1dConvertsCorrectly) {
-  const auto graph = R"IR(
-      graph(%0 : Tensor,
-            %1 : Float(4, 3, 3, strides=[9, 3, 1]),
-            %2 : Float(3)):
-        %3 : int = prim::Constant[value=1]()
-        %4 : int = prim::Constant[value=0]()
-        %5 : int = prim::Constant[value=1]()
-        %6 : int = prim::Constant[value=0]()
-        %8 : int[] = prim::ListConstruct(%3)
-        %9 : int[] = prim::ListConstruct(%4)
-        %10 : int[] = prim::ListConstruct(%5)
-        %11 : int[] = prim::ListConstruct(%6)
-        %12 : Tensor = aten::conv_transpose1d(%0, %1, %2, %8, %9, %11, %3, %10)
-        return (%12))IR";
-
-  auto g = std::make_shared<torch::jit::Graph>();
-  torch::jit::parseIR(graph, g.get());
-
-  auto in = at::randint(1, 2, {1, 8, 3}, {at::kCUDA});
-  auto w = at::randint(1, 2, {8, 4, 3}, {at::kCUDA});
-  auto b = at::randint(1, 10, {4}, {at::kCUDA});
-
-  auto jit_in = at::clone(in);
-  auto jit_w = at::clone(w);
-  auto jit_b = at::clone(b);
-
-  auto params = trtorch::core::conversion::get_named_params(g->inputs(), {jit_w, jit_b});
-  auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in});
-
-  auto trt_in = at::clone(in);
-  auto trt_w = at::clone(w);
-  auto trt_b = at::clone(b);
-  params = trtorch::core::conversion::get_named_params(g->inputs(), {trt_w, trt_b});
-  auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in});
-
-  auto trt = trt_results[0].reshape(jit_results[0].sizes());
-
-  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
-}
-
 TEST(Converters, ATenConvolutionNoBiasConvertsCorrectly) {
   const auto graph = R"IR(
       graph(%0 : Tensor,
@@ -740,4 +654,4 @@ TEST(Converters, ATenConvTransposeWithGroupConvertsCorrectly) {
   auto trt = trt_results[0].reshape(jit_results[0].sizes());

   ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
-}
+}

tests/core/lowering/BUILD
(5 additions, 0 deletions)
@@ -26,6 +26,10 @@ cc_test(
     ]
 )

+lowering_test(
+    name = "test_conv1d_pass",
+)
+
 lowering_test(
     name = "test_remove_contiguous_pass",
 )
@@ -61,6 +65,7 @@ lowering_test(
 test_suite(
     name = "lowering_tests",
     tests = [
+        ":test_conv1d_pass",
         ":test_linear_to_addmm",
        ":test_module_fallback_passes",
        ":test_operator_aliasing_pass",
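The source of test_conv1d_pass is not part of this excerpt; a plausible shape for it, modeled on the other lowering tests in this directory, is sketched below (the test name, structure, and FileCheck assertions are assumptions, not the committed test):

#include <memory>
#include <string>

#include <gtest/gtest.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/testing/file_check.h>

#include "core/lowering/passes/passes.h"

TEST(LoweringPasses, Conv1dLowersToConvolution) {
  const std::string src = R"IR(
    graph(%x, %w, %b, %s, %p, %d, %g):
      %out : Tensor = aten::conv1d(%x, %w, %b, %s, %p, %d, %g)
      return (%out))IR";

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(src, g.get());
  trtorch::core::lowering::passes::Conv1DToConvolution(g);

  // conv1d must be gone and _convolution present after the rewrite.
  torch::jit::testing::FileCheck()
      .check_not("aten::conv1d")
      ->check("aten::_convolution")
      ->run(*g);
}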
