@@ -9,82 +9,99 @@ namespace conversion {
9
9
namespace converters {
10
10
namespace impl {
11
11
namespace {
12
- auto conv_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns().pattern({
13
- R"SIG( aten::_convolution(Tensor input, Tensor weight,
14
- Tensor? bias, int[] stride, int[] padding,
15
- int[] dilation, bool transposed,
16
- int[] output_padding, int groups, bool benchmark,
17
- bool deterministic, bool cudnn_enabled) -> (Tensor))SIG" ,
18
- [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
19
- auto in = args[0 ].ITensor (); // assumes non-static input Tensor
20
- auto w = Weights (ctx, args[1 ].unwrapToTensor ());
21
- auto stride = util::toDims (args[3 ].unwrapToIntList ());
22
- LOG_DEBUG (" stride: " << stride);
23
- auto padding = util::toDims (args[4 ].unwrapToIntList ());
24
- LOG_DEBUG (" padding: " << padding);
25
- auto dilation = util::toDims (args[5 ].unwrapToIntList ());
26
- LOG_DEBUG (" dilation: " << dilation);
27
- bool transposed = args[6 ].unwrapToBool ();
28
- auto out_padding = util::toDims (args[7 ].unwrapToIntList ());
29
- LOG_DEBUG (" out_padding: " << out_padding);
30
- int64_t groups = args[8 ].unwrapToInt ();
31
- LOG_DEBUG (" groups: " << groups);
32
12
33
- nvinfer1::ILayer* new_layer;
34
- if (transposed) {
35
- Weights bias;
36
- if (args[2 ].IValue ()->isTensor ()) {
37
- bias = Weights (ctx, args[2 ].unwrapToTensor ());
38
- } else {
39
- bias = Weights (ctx, torch::zeros (args[1 ].unwrapToTensor ().sizes ()[1 ] * groups));
40
- }
13
+ bool add_conv_deconv (ConversionCtx* ctx, const torch::jit::Node* n, args& args) {
14
+ auto in = args[0 ].ITensor (); // assumes non-static input Tensor
15
+ auto w = Weights (ctx, args[1 ].unwrapToTensor ());
16
+ auto stride = util::toDims (args[3 ].unwrapToIntList ());
17
+ LOG_DEBUG (" stride: " << stride);
18
+ auto padding = util::toDims (args[4 ].unwrapToIntList ());
19
+ LOG_DEBUG (" padding: " << padding);
20
+ auto dilation = util::toDims (args[5 ].unwrapToIntList ());
21
+ LOG_DEBUG (" dilation: " << dilation);
22
+ bool transposed = args[6 ].unwrapToBool ();
23
+ auto out_padding = util::toDims (args[7 ].unwrapToIntList ());
24
+ LOG_DEBUG (" out_padding: " << out_padding);
25
+ int64_t groups = args[8 ].unwrapToInt ();
26
+ LOG_DEBUG (" groups: " << groups);
27
+
28
+ nvinfer1::ILayer* new_layer;
29
+ if (transposed) {
30
+ Weights bias;
31
+ if (args[2 ].IValue ()->isTensor ()) {
32
+ bias = Weights (ctx, args[2 ].unwrapToTensor ());
33
+ } else {
34
+ bias = Weights (ctx, torch::zeros (args[1 ].unwrapToTensor ().sizes ()[1 ] * groups));
35
+ }
41
36
42
- // shape of deconvolution's weight: [in, out/groups, ...]
43
- auto deconv = ctx->net ->addDeconvolutionNd (
44
- *in, args[1 ].unwrapToTensor ().sizes ()[1 ] * groups, w.kernel_shape , w.data , bias.data );
45
- TRTORCH_CHECK (deconv, " Unable to create deconvolution layer from node: " << *n);
37
+ // shape of deconvolution's weight: [in, out/groups, ...]
38
+ auto deconv = ctx->net ->addDeconvolutionNd (
39
+ *in, args[1 ].unwrapToTensor ().sizes ()[1 ] * groups, w.kernel_shape , w.data , bias.data );
40
+ TRTORCH_CHECK (deconv, " Unable to create deconvolution layer from node: " << *n);
46
41
47
- deconv->setStrideNd (stride);
48
- deconv->setPaddingNd (padding);
49
- #if NV_TENSORRT_MAJOR > 7 || (NV_TENSORRT_MAJOR == 7 && NV_TENSORRT_MINOR = = 1)
50
- deconv->setDilationNd (dilation);
51
- deconv->setNbGroups (groups);
42
+ deconv->setStrideNd (stride);
43
+ deconv->setPaddingNd (padding);
44
+ #if NV_TENSORRT_MAJOR > 7 || (NV_TENSORRT_MAJOR == 7 && NV_TENSORRT_MINOR > = 1)
45
+ deconv->setDilationNd (dilation);
46
+ deconv->setNbGroups (groups);
52
47
#else
53
- TRTORCH_CHECK (groups == 1 , " for deconv with groups > 1, require TensorRT version >= 7.1" );
54
- for (auto it = dilation. begin (); it != dilation.end (); ++it ) {
55
- TRTORCH_CHECK (*it == 1 , " for deconv with dilation > 1, require TensorRT version >= 7.1" );
56
- }
48
+ TRTORCH_CHECK (groups == 1 , " for deconv with groups > 1, require TensorRT version >= 7.1" );
49
+ for (int idx = 0 ; idx < dilation.nbDims ; idx++ ) {
50
+ TRTORCH_CHECK (dilation. d [idx] == 1 , " for deconv with dilation > 1, require TensorRT version >= 7.1" );
51
+ }
57
52
#endif
58
- new_layer = deconv;
59
- } else {
60
- Weights bias;
61
- if (args[2 ].IValue ()->isTensor ()) {
62
- bias = Weights (ctx, args[2 ].unwrapToTensor ());
63
- } else {
64
- bias = Weights (ctx, torch::zeros (args[1 ].unwrapToTensor ().sizes ()[0 ]));
65
- }
53
+ new_layer = deconv;
54
+ } else {
55
+ Weights bias;
56
+ if (args[2 ].IValue ()->isTensor ()) {
57
+ bias = Weights (ctx, args[2 ].unwrapToTensor ());
58
+ } else {
59
+ bias = Weights (ctx, torch::zeros (args[1 ].unwrapToTensor ().sizes ()[0 ]));
60
+ }
66
61
67
- // shape of convolution's weight: [out, in/groups, ...]
68
- auto conv =
69
- ctx->net ->addConvolutionNd (*in, args[1 ].unwrapToTensor ().sizes ()[0 ], w.kernel_shape , w.data , bias.data );
70
- TRTORCH_CHECK (conv, " Unable to create convolution layer from node: " << *n);
62
+ // shape of convolution's weight: [out, in/groups, ...]
63
+ auto conv = ctx->net ->addConvolutionNd (*in, args[1 ].unwrapToTensor ().sizes ()[0 ], w.kernel_shape , w.data , bias.data );
64
+ TRTORCH_CHECK (conv, " Unable to create convolution layer from node: " << *n);
71
65
72
- conv->setStrideNd (stride);
73
- conv->setPaddingMode (nvinfer1::PaddingMode::kCAFFE_ROUND_DOWN );
74
- conv->setPaddingNd (padding);
75
- conv->setPostPadding (out_padding);
76
- conv->setDilationNd (dilation);
77
- conv->setNbGroups (groups);
78
- new_layer = conv;
79
- }
80
- new_layer->setName (util::node_info (n).c_str ());
66
+ conv->setStrideNd (stride);
67
+ conv->setPaddingMode (nvinfer1::PaddingMode::kCAFFE_ROUND_DOWN );
68
+ conv->setPaddingNd (padding);
69
+ conv->setPostPadding (out_padding);
70
+ conv->setDilationNd (dilation);
71
+ conv->setNbGroups (groups);
72
+ new_layer = conv;
73
+ }
74
+ new_layer->setName (util::node_info (n).c_str ());
81
75
82
- auto out = ctx->AssociateValueAndTensor (n->outputs ()[0 ], new_layer->getOutput (0 ));
76
+ auto out = ctx->AssociateValueAndTensor (n->outputs ()[0 ], new_layer->getOutput (0 ));
83
77
84
- LOG_DEBUG (" Output tensor shape: " << out->getDimensions ());
78
+ LOG_DEBUG (" Output tensor shape: " << out->getDimensions ());
85
79
86
- return true ;
87
- }});
80
+ return true ;
81
+ }
82
+
83
+ auto conv_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
84
+ .pattern({
85
+ R"SIG( aten::_convolution(Tensor input, Tensor weight,
86
+ Tensor? bias, int[] stride, int[] padding,
87
+ int[] dilation, bool transposed,
88
+ int[] output_padding, int groups, bool benchmark,
89
+ bool deterministic, bool cudnn_enabled, bool allow_tf32) -> (Tensor))SIG" ,
90
+ [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
91
+ return add_conv_deconv (ctx, n, args);
92
+ }})
93
+ .pattern({
94
+ R"SIG( aten::_convolution.deprecated(Tensor input, Tensor weight,
95
+ Tensor? bias, int[] stride, int[] padding,
96
+ int[] dilation, bool transposed,
97
+ int[] output_padding, int groups, bool benchmark,
98
+ bool deterministic, bool cudnn_enabled) -> (Tensor))SIG" ,
99
+ [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
100
+ // This pattern is only matched for traced JIT models which do not
101
+ // have allow_tf32 bool in the function signature. The TRT conversion
102
+ // code is exactly same as the above call.
103
+ return add_conv_deconv (ctx, n, args);
104
+ }});
88
105
} // namespace
89
106
} // namespace impl
90
107
} // namespace converters
0 commit comments