@@ -11,7 +11,8 @@ void pointwise_test_helper(
11
11
bool dynamicInput = false ,
12
12
std::vector<int64_t > shape1 = {5 },
13
13
std::vector<int64_t > shape2 = {5 },
14
- bool negative_input = false ) {
14
+ bool negative_input = false ,
15
+ bool int_tensors = false ) {
15
16
auto g = std::make_shared<torch::jit::Graph>();
16
17
torch::jit::parseIR (graph_ir, g.get ());
17
18
@@ -26,6 +27,11 @@ void pointwise_test_helper(
26
27
if (!singleInput) {
27
28
torch_inputs.push_back (at::randint (1 , 5 , shape2, {at::kCUDA }));
28
29
}
30
+ if (int_tensors){
31
+ for (size_t i = 0UL ; i < torch_inputs.size (); ++i){
32
+ torch_inputs[i] = torch_inputs[i].to (at::kInt );
33
+ }
34
+ }
29
35
auto params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
30
36
auto jit_results = torch_tensorrt::tests::util::RunGraph (g, params, torch_inputs);
31
37
@@ -126,6 +132,15 @@ TEST(Converters, ATenMulWithScalarConvertsCorrectly) {
126
132
pointwise_test_helper (graph, true );
127
133
}
128
134
135
+ TEST (Converters, ATenMulWithIntScalarConvertsCorrectly) {
136
+ const auto graph = R"IR(
137
+ graph(%0 : Tensor):
138
+ %scalar : int = prim::Constant[value=2]()
139
+ %1 : Tensor = aten::mul(%0, %scalar)
140
+ return (%1))IR" ;
141
+ pointwise_test_helper (graph, true , false , {5 }, {5 }, false , true );
142
+ }
143
+
129
144
TEST (Converters, ATenDivConvertsCorrectly) {
130
145
const auto graph = R"IR(
131
146
graph(%0 : Tensor, %1 : Tensor):
0 commit comments