Skip to content

Commit 877516d

Browse files
authored
Merge pull request #362 from NVIDIA/bowa_core_elu
add ELU converter to core library
2 parents 37305c0 + 100c9f8 commit 877516d

File tree

2 files changed

+38
-0
lines changed

2 files changed

+38
-0
lines changed

core/conversion/converters/impl/activation.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,21 @@ auto acthardtanh TRTORCH_UNUSED =
152152
out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], out_tensor);
153153
LOG_DEBUG("Output shape: " << out_tensor->getDimensions());
154154
return true;
155+
}})
156+
.pattern({"aten::elu(Tensor self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1) -> (Tensor)",
157+
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
158+
auto in = args[0].ITensorOrFreeze(ctx);
159+
auto alpha = args[1].unwrapToDouble();
160+
161+
auto new_layer = ctx->net->addActivation(*in, nvinfer1::ActivationType::kELU);
162+
TRTORCH_CHECK(new_layer, "Unable to create layer for aten::elu");
163+
new_layer->setAlpha(alpha);
164+
165+
new_layer->setName(trtorch::core::util::node_info(n).c_str());
166+
167+
auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], new_layer->getOutput(0));
168+
LOG_DEBUG("Output shape: " << out_tensor->getDimensions());
169+
return true;
155170
}});
156171

157172
} // namespace

tests/core/conversion/converters/test_activation.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,5 +175,28 @@ TEST(Converters, ATenLeakyReluConvertsCorrectly) {
175175
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
176176
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
177177

178+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
179+
}
180+
181+
TEST(Converters, ATenEluConvertsCorrectly) {
182+
const auto graph = R"IR(
183+
graph(%x.1 : Tensor):
184+
%2 : float = prim::Constant[value=1.]()
185+
%3 : int = prim::Constant[value=1]()
186+
%result.2 : Tensor = aten::elu(%x.1, %2, %3, %3)
187+
return (%result.2))IR";
188+
189+
auto g = std::make_shared<torch::jit::Graph>();
190+
torch::jit::parseIR(graph, &*g);
191+
192+
auto in = at::randint(-5, 5, {1, 10, 1, 1}, {at::kCUDA});
193+
194+
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
195+
auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});
196+
197+
in = at::clone(in);
198+
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
199+
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
200+
178201
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
179202
}

0 commit comments

Comments
 (0)