44#include " tests/util/util.h"
55#include " core/compiler.h"
66
7- void pointwise_test_helper (std::string graph_ir) {
7+ void pointwise_test_helper (std::string graph_ir, bool singleInput ) {
88 auto g = std::make_shared<torch::jit::Graph>();
99 torch::jit::parseIR (graph_ir, &*g);
10-
11- auto in0 = at::randint (1 , 5 , {5 }, {at::kCUDA });
12- auto in1 = at::randint (1 , 5 , {5 }, {at::kCUDA });
10+
11+ // singleInput case is enabled when elementwise operation is performed
12+ // with an input and a constant embedded in graph
13+ std::vector<at::Tensor> torch_inputs;
14+ torch_inputs.push_back (at::randint (1 , 5 , {5 }, {at::kCUDA }));
15+ if (!singleInput) {
16+ torch_inputs.push_back (at::randint (1 , 5 , {5 }, {at::kCUDA }));
17+ }
1318 auto params = trtorch::core::conversion::get_named_params (g->inputs (), {});
14- auto jit_results = trtorch::tests::util::RunGraph (g, params, {in0, in1});
19+ auto jit_results = trtorch::tests::util::RunGraph (g, params, torch_inputs);
20+
21+ std::vector<at::Tensor> trt_inputs;
22+ for (auto in : torch_inputs) {
23+ trt_inputs.push_back (at::clone (in));
24+ }
1525
16- in0 = at::clone (in0);
17- in1 = at::clone (in1);
1826 params = trtorch::core::conversion::get_named_params (g->inputs (), {});
19- auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {in0, in1} );
27+ auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, trt_inputs );
2028
2129 ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt_results[0 ], 2e-6 ));
2230}
2331
2432
25-
2633TEST (Converters, ATenAddConvertsCorrectly) {
2734 const auto graph = R"IR(
2835 graph(%0 : Tensor, %1 : Tensor):
2936 %2 : int = prim::Constant[value=1]()
3037 %3 : Tensor = aten::add(%0, %1, %2)
3138 return (%3))IR" ;
32- pointwise_test_helper (graph);
39+ pointwise_test_helper (graph, false );
3340}
3441
3542
@@ -39,7 +46,7 @@ TEST(Converters, ATenAddConvertsCorrectly) {
3946// %2 : int = prim::Constant[value=2]()
4047// %3 : Tensor = aten::add(%0, %1, %2)
4148// return (%3))IR";
42- // pointwise_test_helper(graph);
49+ // pointwise_test_helper(graph, false );
4350// }
4451
4552TEST (Converters, ATenSubConvertsCorrectly) {
@@ -48,7 +55,7 @@ TEST(Converters, ATenSubConvertsCorrectly) {
4855 %2 : int = prim::Constant[value=1]()
4956 %3 : Tensor = aten::sub(%0, %1, %2)
5057 return (%3))IR" ;
51- pointwise_test_helper (graph);
58+ pointwise_test_helper (graph, false );
5259}
5360
5461// TEST(Converters, ATenSubWithScaleConvertsCorrectly) {
@@ -57,21 +64,38 @@ TEST(Converters, ATenSubConvertsCorrectly) {
5764// %2 : float = prim::Constant[value=0.5]()
5865// %3 : Tensor = aten::add(%0, %1, %2)
5966// return (%3))IR";
60- // pointwise_test_helper(graph);
67+ // pointwise_test_helper(graph, false );
6168// }
6269
6370TEST (Converters, ATenMulConvertsCorrectly) {
6471 const auto graph = R"IR(
6572 graph(%0 : Tensor, %1 : Tensor):
6673 %2 : Tensor = aten::mul(%0, %1)
6774 return (%2))IR" ;
68- pointwise_test_helper (graph);
75+ pointwise_test_helper (graph, false );
6976}
7077
7178TEST (Converters, ATenDivConvertsCorrectly) {
7279 const auto graph = R"IR(
7380 graph(%0 : Tensor, %1 : Tensor):
7481 %2 : Tensor = aten::div(%0, %1)
7582 return (%2))IR" ;
76- pointwise_test_helper (graph);
83+ pointwise_test_helper (graph, false );
84+ }
85+
86+ TEST (Converters, ATenPowTensorConvertsCorrectly) {
87+ const auto graph = R"IR(
88+ graph(%x.1 : Tensor, %x2.1 : Tensor):
89+ %3 : Tensor = aten::pow(%x.1, %x2.1)
90+ return (%3))IR" ;
91+ pointwise_test_helper (graph, false );
92+ }
93+
94+ TEST (Converters, ATenPowScalarConvertsCorrectly) {
95+ const auto graph = R"IR(
96+ graph(%x.1 : Tensor):
97+ %2 : int = prim::Constant[value=2]()
98+ %3 : Tensor = aten::pow(%x.1, %2)
99+ return (%3))IR" ;
100+ pointwise_test_helper (graph, true );
77101}
0 commit comments