@@ -1122,6 +1122,34 @@ TEST(Converters, ATenUnbindNegativeAxisConvertsCorrectly) {
1122
1122
}
1123
1123
}
1124
1124
1125
+ TEST (Converters, ATenUnbindEvaluatedTensor) {
1126
+ const auto graph = R"IR(
1127
+ graph(%x.1 : Tensor):
1128
+ %2 : None = prim::Constant()
1129
+ %3 : int[] = aten::size(%x.1)
1130
+ %z.1 : Tensor = aten::zeros(%3, %2, %2, %2, %2)
1131
+ %5 : int = prim::Constant[value=-1]()
1132
+ %6 : Tensor[] = aten::unbind(%z.1, %5)
1133
+ %o1.1 : Tensor, %o2.1 : Tensor = prim::ListUnpack(%6)
1134
+ return (%o1.1, %o2.1))IR" ;
1135
+
1136
+ auto in = at::randint (1 , 10 , {2 }, {at::kCUDA });
1137
+
1138
+ auto g = std::make_shared<torch::jit::Graph>();
1139
+
1140
+ torch::jit::parseIR (graph, g.get ());
1141
+
1142
+ auto params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
1143
+ auto jit_results = torch_tensorrt::tests::util::RunGraph (g, params, {in});
1144
+
1145
+ auto trt_results = torch_tensorrt::tests::util::RunGraphEngine (g, params, {in});
1146
+
1147
+ for (size_t i = 0 ; i < jit_results.size (); i++) {
1148
+ auto trt = trt_results[i];
1149
+ ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[i].cuda (), trt, 2e-6 ));
1150
+ }
1151
+ }
1152
+
1125
1153
TEST (Converters, ScatterValueConvertsCorrectly) {
1126
1154
const auto graph = R"IR(
1127
1155
graph(%data : Tensor,
0 commit comments