@@ -18,8 +18,7 @@ TEST(Converters, ATenStackPureTensorConvertsCorrectly) {
1818 params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
1919 auto trt_results = torch_tensorrt::tests::util::RunGraphEngine (g, params, {in1, in2});
2020
21- ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (
22- jit_results[0 ], trt_results[0 ].reshape_as (jit_results[0 ]), THRESHOLD_E5));
21+ ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt_results[0 ], THRESHOLD_E5));
2322 };
2423 const auto graph = R"IR(
2524 graph(%0 : Tensor,
@@ -35,9 +34,17 @@ TEST(Converters, ATenStackPureTensorConvertsCorrectly) {
3534 %3 : int = prim::Constant[value=-1]()
3635 %4 : Tensor = aten::stack(%2, %3)
3736 return (%4))IR" ;
37+ const auto graph3 = R"IR(
38+ graph(%0 : Tensor,
39+ %1 : Tensor):
40+ %2 : Tensor[] = prim::ListConstruct(%0, %1)
41+ %3 : int = prim::Constant[value=-2]()
42+ %4 : Tensor = aten::stack(%2, %3)
43+ return (%4))IR" ;
3844
3945 TestATenStackPureTensorConvertsCorrectly (graph);
4046 TestATenStackPureTensorConvertsCorrectly (graph2);
47+ TestATenStackPureTensorConvertsCorrectly (graph3);
4148}
4249
4350TEST (Converters, ATenStackPureTensorDynamicConvertsCorrectly) {
@@ -89,8 +96,7 @@ TEST(Converters, ATenStackDiffTensorConvertsCorrectly) {
8996 params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {in2});
9097 auto trt_results = torch_tensorrt::tests::util::RunGraphEngine (g, params, {in1});
9198
92- ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (
93- jit_results[0 ], trt_results[0 ].reshape_as (jit_results[0 ]), THRESHOLD_E5));
99+ ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt_results[0 ], THRESHOLD_E5));
94100 };
95101 const auto graph = R"IR(
96102 graph(%0 : Tensor,
@@ -106,6 +112,14 @@ TEST(Converters, ATenStackDiffTensorConvertsCorrectly) {
106112 %3 : int = prim::Constant[value=-1]()
107113 %4 : Tensor = aten::stack(%2, %3)
108114 return (%4))IR" ;
115+ const auto graph3 = R"IR(
116+ graph(%0 : Tensor,
117+ %1 : Float(4, 4, 4, strides=[16, 4, 1])):
118+ %2 : Tensor[] = prim::ListConstruct(%0, %1)
119+ %3 : int = prim::Constant[value=-3]()
120+ %4 : Tensor = aten::stack(%2, %3)
121+ return (%4))IR" ;
109122 TestATenStackDiffTensorConvertsCorrectly (graph);
110123 TestATenStackDiffTensorConvertsCorrectly (graph2);
124+ TestATenStackDiffTensorConvertsCorrectly (graph3);
111125}
0 commit comments