@@ -29,6 +29,85 @@ TEST(Converters, ATenCatPureTensorConvertsCorrectly) {
       torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
 }
 
+TEST(Converters, ATenCatFloatIntConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%0 : Tensor,
+            %1 : Tensor):
+        %2 : Tensor[] = prim::ListConstruct(%0, %1)
+        %3 : int = prim::Constant[value=0]()
+        %4 : Tensor = aten::cat(%2, %3)
+        return (%4))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto in1 = at::randint(1, 10, {5}, {at::kCUDA}).to(at::kFloat);
+  auto in2 = at::randint(1, 10, {5}, {at::kCUDA}).to(at::kInt);
+
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1, in2});
+
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, in2});
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
+}
+
+TEST(Converters, ATenCatIntHalfIntHalfConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%0 : Tensor,
+            %1 : Tensor,
+            %2 : Tensor,
+            %3 : Tensor):
+        %4 : Tensor[] = prim::ListConstruct(%0, %1, %2, %3)
+        %5 : int = prim::Constant[value=0]()
+        %6 : Tensor = aten::cat(%4, %5)
+        return (%6))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto in1 = at::randint(1, 10, {5}, {at::kCUDA}).to(at::kInt);
+  auto in2 = at::randint(1, 10, {5}, {at::kCUDA}).to(at::kHalf);
+  auto in3 = at::randint(1, 10, {5}, {at::kCUDA}).to(at::kInt);
+  auto in4 = at::randint(1, 10, {5}, {at::kCUDA}).to(at::kHalf);
+
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1, in2, in3, in4});
+
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results =
+      torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, in2, in3, in4}, nvinfer1::DataType::kHALF);
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
+}
+
+TEST(Converters, ATenCatHalfIntFloatConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%0 : Tensor,
+            %1 : Tensor,
+            %2 : Tensor):
+        %3 : Tensor[] = prim::ListConstruct(%0, %1, %2)
+        %4 : int = prim::Constant[value=0]()
+        %5 : Tensor = aten::cat(%3, %4)
+        return (%5))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto in1 = at::randint(1, 10, {5}, {at::kCUDA}).to(at::kInt);
+  auto in2 = at::randint(1, 10, {5}, {at::kCUDA}).to(at::kHalf);
+  auto in3 = at::randint(1, 10, {5}, {at::kCUDA}).to(at::kFloat);
+
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1, in2, in3});
+
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, in2, in3});
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
+}
+
 TEST(Converters, ATenCatDiffTensorConvertsCorrectly) {
   const auto graph = R"IR(
       graph(%0 : Tensor,