Skip to content

Commit 24a780f

Browse files
authored
Merge pull request #397 from guoruoqian/aten_size_fix_bug
fix bug, when dim of aten::size.int(Tensor self, int dim) -> (int) is…
2 parents 4da65b3 + efc8202 commit 24a780f

File tree

2 files changed

+35
-1
lines changed

2 files changed

+35
-1
lines changed

core/conversion/evaluators/aten.cpp

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,9 +180,18 @@ auto aten_registrations TRTORCH_UNUSED =
180180
auto dim = args.at(n->input(1)).unwrapToInt();
181181
if (tensor_var.isITensor()) {
182182
auto tensor = tensor_var.ITensor();
183-
return util::toVec(tensor->getDimensions())[dim];
183+
auto dims = util::toVec(tensor->getDimensions());
184+
auto nbDims = tensor->getDimensions().nbDims;
185+
if (dim < 0) {
186+
dim += nbDims;
187+
}
188+
return dims[dim];
184189
} else {
185190
auto tensor = tensor_var.unwrapToTensor();
191+
auto nbDims = tensor.sizes().size();
192+
if (dim < 0) {
193+
dim += nbDims;
194+
}
186195
return tensor.sizes()[dim];
187196
}
188197
}

tests/core/conversion/evaluators/test_aten_evaluators.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,31 @@ TEST(Evaluators, ATenArangeStartEndStepFloatEvaluatesCorrectly) {
180180
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0].toTensor(), trt_results[0].toTensor(), 2e-6));
181181
}
182182

183+
TEST(Evaluators, ATenSizeNegativeConvertsCorrectly) {
184+
const auto graph = R"IR(
185+
graph(%0 : Tensor):
186+
%1 : int = prim::Constant[value=-1]()
187+
%2 : int = prim::Constant[value=-2]()
188+
%3 : int = aten::size(%0, %1)
189+
%4 : int = aten::size(%0, %2)
190+
%5 : int[] = prim::ListConstruct(%3, %4)
191+
%6 : Tensor = aten::view(%0, %5)
192+
return (%6))IR";
193+
194+
auto g = std::make_shared<torch::jit::Graph>();
195+
torch::jit::parseIR(graph, &*g);
196+
197+
auto in = at::randint(1, 10, {3, 3}, {at::kCUDA});
198+
199+
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
200+
auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});
201+
202+
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
203+
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
204+
205+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
206+
}
207+
183208
TEST(Evaluators, FloorIntIntEvaluatesCorrectly) {
184209
const auto graph = R"IR(
185210
graph():

0 commit comments

Comments
 (0)