Commit f06735a

add patch file to disable elu converter in core library
Signed-off-by: Bo Wang <[email protected]>
1 parent d012322 commit f06735a

File tree

2 files changed (+33, -0 lines changed)

examples/custom_converters/README.md

Lines changed: 7 additions & 0 deletions
```diff
@@ -12,6 +12,13 @@ Converter for aten::elu requested, but no such converter was found.
 If you need a converter for this operator, you can try implementing one yourself
 or request a converter: https://www.github.com/NVIDIA/TRTorch/issues
 
+Note that the ELU converter is now supported in our library. If you want to
+reproduce the error above and run the example in this document, you can either:
+1. Get the source code, go to the root directory, and run: <br />
+`git apply ./examples/custom_converters/elu_converter/disable_core_elu.patch`
+2. Use a pre-built release of TRTorch that does not support the elu operator
+by default (TRTorch <= v0.1.0).
+
 ## Writing Converter in C++
 We can register a converter for this operator in our application. You can find more
 information on all the details of writing converters in the contributors documentation
```
examples/custom_converters/elu_converter/disable_core_elu.patch

Lines changed: 26 additions & 0 deletions
```diff
diff --git a/core/conversion/converters/impl/activation.cpp b/core/conversion/converters/impl/activation.cpp
index 64edeed..5279413 100644
--- a/core/conversion/converters/impl/activation.cpp
+++ b/core/conversion/converters/impl/activation.cpp
@@ -152,21 +152,6 @@ auto acthardtanh TRTORCH_UNUSED =
                     out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], out_tensor);
                     LOG_DEBUG("Output shape: " << out_tensor->getDimensions());
                     return true;
-                  }})
-        .pattern({"aten::elu(Tensor self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1) -> (Tensor)",
-                  [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
-                    auto in = args[0].ITensorOrFreeze(ctx);
-                    auto alpha = args[1].unwrapToDouble();
-
-                    auto new_layer = ctx->net->addActivation(*in, nvinfer1::ActivationType::kELU);
-                    TRTORCH_CHECK(new_layer, "Unable to create layer for aten::elu");
-                    new_layer->setAlpha(alpha);
-
-                    new_layer->setName(trtorch::core::util::node_info(n).c_str());
-
-                    auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], new_layer->getOutput(0));
-                    LOG_DEBUG("Output shape: " << out_tensor->getDimensions());
-                    return true;
                   }});
 
 } // namespace
```
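The hunk removed by this patch is the in-core ELU converter, which the custom-converter example then re-registers from application code. For reference, here is a minimal sketch of that re-registration: the lambda body mirrors the removed hunk above, while the header paths, the `RegisterNodeConversionPatterns` entry point, and the `elu_registration` variable name are assumptions about the TRTorch converter registry rather than code from this commit.

```cpp
// Minimal sketch: re-registering the ELU converter from application code after
// the in-core converter has been patched out. The lambda body mirrors the
// removed hunk; the includes, RegisterNodeConversionPatterns, and the
// `elu_registration` name are assumptions, not code from this commit.
#include "core/conversion/converters/converters.h"
#include "core/util/prelude.h"

namespace {

auto elu_registration TRTORCH_UNUSED =
    trtorch::core::conversion::converters::RegisterNodeConversionPatterns().pattern(
        {"aten::elu(Tensor self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1) -> (Tensor)",
         [](trtorch::core::conversion::ConversionCtx* ctx,
            const torch::jit::Node* n,
            trtorch::core::conversion::converters::args& args) -> bool {
           // Fetch the input, freezing it to a constant if it is a weight.
           auto in = args[0].ITensorOrFreeze(ctx);
           auto alpha = args[1].unwrapToDouble();

           // Add a TensorRT ELU activation layer and configure its alpha.
           auto new_layer = ctx->net->addActivation(*in, nvinfer1::ActivationType::kELU);
           TRTORCH_CHECK(new_layer, "Unable to create layer for aten::elu");
           new_layer->setAlpha(alpha);
           new_layer->setName(trtorch::core::util::node_info(n).c_str());

           // Associate the JIT node output with the layer output so downstream
           // converters can look it up.
           auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], new_layer->getOutput(0));
           LOG_DEBUG("Output shape: " << out_tensor->getDimensions());
           return true;
         }});

} // namespace
```

Because registration happens in a static initializer, linking a translation unit like this into the application should be enough to restore aten::elu support without rebuilding the core library.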
