@@ -9,37 +9,42 @@ namespace converters {
9
9
namespace impl {
10
10
namespace {
11
11
12
- bool relu (ConversionCtx* ctx, const torch::jit::Node* n, args& args) {
13
- auto in = args[0 ].ITensor ();
12
+ #define convert (act, trt_type ) \
13
+ bool act (ConversionCtx* ctx, const torch::jit::Node* n, args& args) { \
14
+ auto in = args[0 ].ITensor (); \
15
+ \
16
+ auto new_layer = \
17
+ ctx->net ->addActivation (*in, nvinfer1::ActivationType::trt_type); \
18
+ TRTORCH_CHECK (new_layer, \
19
+ " Unable to create " #act " layer from node: " << *n); \
20
+ \
21
+ new_layer->setName (util::node_info (n).c_str ()); \
22
+ auto out_value = n->outputs ()[0 ]; \
23
+ auto out_tensor = new_layer->getOutput (0 ); \
24
+ out_tensor->setName (out_value->debugName ().c_str ()); \
25
+ ctx->value_tensor_map [out_value] = out_tensor; \
26
+ LOG_DEBUG (" Output tensor shape: " << out_tensor->getDimensions ()); \
27
+ \
28
+ return true ; \
29
+ } \
30
+ \
31
+ auto act##_registrations TRTORCH_UNUSED = \
32
+ RegisterNodeConversionPatterns () \
33
+ .pattern({" aten::" #act " (Tensor input) -> (Tensor)" , \
34
+ [](ConversionCtx *ctx, const torch::jit::Node *n, \
35
+ args &args) -> bool { return act (ctx, n, args); }}) \
36
+ .pattern({" aten::" #act " _(Tensor(a!) self) -> (Tensor(a!))" , \
37
+ [](ConversionCtx *ctx, const torch::jit::Node *n, \
38
+ args &args) -> bool { return act (ctx, n, args); }});
14
39
15
- auto new_layer = ctx->net ->addActivation (*in, nvinfer1::ActivationType::kRELU );
16
- TRTORCH_CHECK (new_layer, " Unable to create ReLU layer from node: " << *n);
40
+ convert (relu, kRELU );
41
+ convert (sigmoid, kSIGMOID );
42
+ convert (tanh, kTANH );
17
43
18
- new_layer->setName (util::node_info (n).c_str ());
19
- auto out_value = n->outputs ()[0 ];
20
- auto out_tensor = new_layer->getOutput (0 );
21
- out_tensor->setName (out_value->debugName ().c_str ());
22
- ctx->value_tensor_map [out_value] = out_tensor;
23
- LOG_DEBUG (" Output tensor shape: " << out_tensor->getDimensions ());
24
-
25
- return true ;
26
- }
27
-
28
- auto relu_registrations = RegisterNodeConversionPatterns()
29
- .pattern({
30
- " aten::relu(Tensor input) -> (Tensor)" ,
31
- [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
32
- return relu (ctx, n, args);
33
- }
34
- }).pattern({
35
- " aten::relu_(Tensor(a!) self) -> (Tensor(a!))" ,
36
- [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
37
- return relu (ctx, n, args);
38
- }
39
- });
44
+ #undef convert
40
45
} // namespace
41
46
} // namespace impl
42
47
} // namespace converters
43
48
} // namespace conversion
44
49
} // namespace core
45
- } // trtorch
50
+ } // namespace trtorch
0 commit comments