Skip to content

Commit 327191c

Browse files
committed
fix bug, when dim of aten::size.int(Tensor self, int dim) -> (int) is negative
Signed-off-by: Ruoqian Guo <[email protected]>
1 parent 6bb9fbf commit 327191c

File tree

2 files changed

+33
-1
lines changed

2 files changed

+33
-1
lines changed

core/conversion/evaluators/aten.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,9 +180,16 @@ auto aten_registrations TRTORCH_UNUSED =
180180
auto dim = args.at(n->input(1)).unwrapToInt();
181181
if (tensor_var.isITensor()) {
182182
auto tensor = tensor_var.ITensor();
183-
return util::toVec(tensor->getDimensions())[dim];
183+
auto dims = util::toVec(tensor->getDimensions());
184+
auto nbDims = tensor->getDimensions().nbDims;
185+
if (dim < 0)
186+
dim += nbDims;
187+
return dims[dim];
184188
} else {
185189
auto tensor = tensor_var.unwrapToTensor();
190+
auto nbDims = tensor.sizes().size();
191+
if (dim < 0)
192+
dim += nbDims;
186193
return tensor.sizes()[dim];
187194
}
188195
}

tests/core/conversion/evaluators/test_aten_evaluators.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,4 +75,29 @@ TEST(Evaluators, ZerosDataTypeEvaluatesCorrectly) {
7575
auto trt_results = trtorch::tests::util::EvaluateGraph(g->block(), {in});
7676

7777
ASSERT_TRUE(at::equal(jit_results[0].toTensor().to(at::kCUDA), trt_results[0].toTensor()));
78+
}
79+
80+
TEST(Evaluators, SizeConvertsCorrectly) {
81+
const auto graph = R"IR(
82+
graph(%0 : Tensor):
83+
%1 : int = prim::Constant[value=-1]()
84+
%2 : int = prim::Constant[value=-2]()
85+
%3 : int = aten::size(%0, %1)
86+
%4 : int = aten::size(%0, %2)
87+
%5 : int[] = prim::ListConstruct(%3, %4)
88+
%6 : Tensor = aten::view(%0, %5)
89+
return (%6))IR";
90+
91+
auto g = std::make_shared<torch::jit::Graph>();
92+
torch::jit::parseIR(graph, &*g);
93+
94+
auto in = at::randint(1, 10, {3, 3}, {at::kCUDA});
95+
96+
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
97+
auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});
98+
99+
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
100+
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
101+
102+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
78103
}

0 commit comments

Comments
 (0)