Skip to content

Commit f5aa404

Browse files
authored
Merge pull request #1095 from mfeliz-cruise/michael.feliz/int_scalar_mul
fix: support int tensor * int scaler in aten::mul
2 parents eec3172 + b701b42 commit f5aa404

File tree

2 files changed

+17
-3
lines changed

2 files changed

+17
-3
lines changed

core/conversion/converters/impl/element_wise.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -425,8 +425,7 @@ auto element_wise_registrations TORCHTRT_UNUSED =
425425
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
426426
// TODO: Remove with functionalization
427427
auto self = args[0].ITensorOrFreeze(ctx);
428-
auto otherScalar = args[1].unwrapToScalar().to<float>();
429-
auto other = tensor_to_const(ctx, torch::tensor({otherScalar}));
428+
auto other = scalar_to_tensor(ctx, args[1].unwrapToScalar());
430429
auto mul =
431430
add_elementwise(ctx, nvinfer1::ElementWiseOperation::kPROD, self, other, util::node_info(n));
432431
TORCHTRT_CHECK(mul, "Unable to create mul layer from node: " << *n);

tests/core/conversion/converters/test_element_wise.cpp

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@ void pointwise_test_helper(
1111
bool dynamicInput = false,
1212
std::vector<int64_t> shape1 = {5},
1313
std::vector<int64_t> shape2 = {5},
14-
bool negative_input = false) {
14+
bool negative_input = false,
15+
bool int_tensors = false) {
1516
auto g = std::make_shared<torch::jit::Graph>();
1617
torch::jit::parseIR(graph_ir, g.get());
1718

@@ -26,6 +27,11 @@ void pointwise_test_helper(
2627
if (!singleInput) {
2728
torch_inputs.push_back(at::randint(1, 5, shape2, {at::kCUDA}));
2829
}
30+
if(int_tensors){
31+
for(size_t i = 0UL; i < torch_inputs.size(); ++i){
32+
torch_inputs[i] = torch_inputs[i].to(at::kInt);
33+
}
34+
}
2935
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
3036
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, torch_inputs);
3137

@@ -126,6 +132,15 @@ TEST(Converters, ATenMulWithScalarConvertsCorrectly) {
126132
pointwise_test_helper(graph, true);
127133
}
128134

135+
TEST(Converters, ATenMulWithIntScalarConvertsCorrectly) {
136+
const auto graph = R"IR(
137+
graph(%0 : Tensor):
138+
%scalar : int = prim::Constant[value=2]()
139+
%1 : Tensor = aten::mul(%0, %scalar)
140+
return (%1))IR";
141+
pointwise_test_helper(graph, true, false, {5}, {5}, false, true);
142+
}
143+
129144
TEST(Converters, ATenDivConvertsCorrectly) {
130145
const auto graph = R"IR(
131146
graph(%0 : Tensor, %1 : Tensor):

0 commit comments

Comments
 (0)