@@ -175,5 +175,28 @@ TEST(Converters, ATenLeakyReluConvertsCorrectly) {
175
175
params = trtorch::core::conversion::get_named_params (g->inputs (), {});
176
176
auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {in});
177
177
178
+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt_results[0 ], 2e-6 ));
179
+ }
180
+
181
+ TEST (Converters, ATenEluConvertsCorrectly) {
182
+ const auto graph = R"IR(
183
+ graph(%x.1 : Tensor):
184
+ %2 : float = prim::Constant[value=1.]()
185
+ %3 : int = prim::Constant[value=1]()
186
+ %result.2 : Tensor = aten::elu(%x.1, %2, %3, %3)
187
+ return (%result.2))IR" ;
188
+
189
+ auto g = std::make_shared<torch::jit::Graph>();
190
+ torch::jit::parseIR (graph, &*g);
191
+
192
+ auto in = at::randint (-5 , 5 , {1 , 10 , 1 , 1 }, {at::kCUDA });
193
+
194
+ auto params = trtorch::core::conversion::get_named_params (g->inputs (), {});
195
+ auto jit_results = trtorch::tests::util::RunGraph (g, params, {in});
196
+
197
+ in = at::clone (in);
198
+ params = trtorch::core::conversion::get_named_params (g->inputs (), {});
199
+ auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {in});
200
+
178
201
ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt_results[0 ], 2e-6 ));
179
202
}
0 commit comments