@@ -8,37 +8,42 @@ namespace converters {
8
8
namespace impl {
9
9
namespace {
10
10
11
- bool relu (ConversionCtx* ctx, const torch::jit::Node* n, args& args) {
12
- auto in = args[0 ].ITensor ();
11
+ #define convert (act, trt_type ) \
12
+ bool act (ConversionCtx* ctx, const torch::jit::Node* n, args& args) { \
13
+ auto in = args[0 ].ITensor (); \
14
+ \
15
+ auto new_layer = \
16
+ ctx->net ->addActivation (*in, nvinfer1::ActivationType::trt_type); \
17
+ TRTORCH_CHECK (new_layer, \
18
+ " Unable to create " #act " layer from node: " << *n); \
19
+ \
20
+ new_layer->setName (util::node_info (n).c_str ()); \
21
+ auto out_value = n->outputs ()[0 ]; \
22
+ auto out_tensor = new_layer->getOutput (0 ); \
23
+ out_tensor->setName (out_value->debugName ().c_str ()); \
24
+ ctx->value_tensor_map [out_value] = out_tensor; \
25
+ LOG_DEBUG (" Output tensor shape: " << out_tensor->getDimensions ()); \
26
+ \
27
+ return true ; \
28
+ } \
29
+ \
30
+ auto act##_registrations TRTORCH_UNUSED = \
31
+ RegisterNodeConversionPatterns () \
32
+ .pattern({" aten::" #act " (Tensor input) -> (Tensor)" , \
33
+ [](ConversionCtx *ctx, const torch::jit::Node *n, \
34
+ args &args) -> bool { return act (ctx, n, args); }}) \
35
+ .pattern({" aten::" #act " _(Tensor(a!) self) -> (Tensor(a!))" , \
36
+ [](ConversionCtx *ctx, const torch::jit::Node *n, \
37
+ args &args) -> bool { return act (ctx, n, args); }});
13
38
14
- auto new_layer = ctx->net ->addActivation (*in, nvinfer1::ActivationType::kRELU );
15
- TRTORCH_CHECK (new_layer, " Unable to create ReLU layer from node: " << *n);
39
+ convert (relu, kRELU );
40
+ convert (sigmoid, kSIGMOID );
41
+ convert (tanh, kTANH );
16
42
17
- new_layer->setName (util::node_info (n).c_str ());
18
- auto out_value = n->outputs ()[0 ];
19
- auto out_tensor = new_layer->getOutput (0 );
20
- out_tensor->setName (out_value->debugName ().c_str ());
21
- ctx->value_tensor_map [out_value] = out_tensor;
22
- LOG_DEBUG (" Output tensor shape: " << out_tensor->getDimensions ());
23
-
24
- return true ;
25
- }
26
-
27
- auto relu_registrations = RegisterNodeConversionPatterns()
28
- .pattern({
29
- " aten::relu(Tensor input) -> (Tensor)" ,
30
- [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
31
- return relu (ctx, n, args);
32
- }
33
- }).pattern({
34
- " aten::relu_(Tensor(a!) self) -> (Tensor(a!))" ,
35
- [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
36
- return relu (ctx, n, args);
37
- }
38
- });
43
+ #undef convert
39
44
} // namespace
40
45
} // namespace impl
41
46
} // namespace converters
42
47
} // namespace conversion
43
48
} // namespace core
44
- } // trtorch
49
+ } // namespace trtorch
0 commit comments