Skip to content

Commit 3d1fbfd

Browse files
authored
Merge pull request #320 from NVIDIA/neg_index_sum
Support -1 indexing for aten::sum
2 parents 78c614a + 769bbc9 commit 3d1fbfd

File tree

2 files changed

+67
-3
lines changed

2 files changed

+67
-3
lines changed

core/conversion/converters/impl/reduce.cpp

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,11 +80,23 @@ auto reduce_registrations TRTORCH_UNUSED =
8080
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
8181
auto in_tensor = args[0].ITensorOrFreeze(ctx);
8282
auto dims = args[1].unwrapToIntList();
83-
LOG_DEBUG("Dim to reduce:" << util::toDims(dims)); // Some abuse of toDim but just for debug info
83+
c10::List<int64_t> calculated_dims;
84+
auto in_dims = util::toVec(in_tensor->getDimensions());
85+
LOG_DEBUG("InDims " << in_dims); // Some abuse of toDim but just for debug info
86+
LOG_DEBUG(
87+
"Dim to reduce(original):" << util::toDims(dims)); // Some abuse of toDim but just for debug info
88+
for (size_t i = 0; i < dims.size(); i++) {
89+
auto dim_val = dims[i] < 0 ? (in_dims.size() + dims[i]) : dims[i];
90+
calculated_dims.push_back(dim_val);
91+
}
92+
93+
LOG_DEBUG(
94+
"Dim to reduce(converted):"
95+
<< util::toDims(calculated_dims)); // Some abuse of toDim but just for debug info
8496

8597
uint32_t axis_mask = 0;
86-
for (size_t d = 0; d < dims.size(); d++) {
87-
axis_mask |= 1 << dims[d];
98+
for (size_t d = 0; d < calculated_dims.size(); d++) {
99+
axis_mask |= 1 << calculated_dims[d];
88100
}
89101
LOG_DEBUG("Axis Mask" << std::bitset<32>(axis_mask));
90102

tests/core/conversion/converters/test_reduce.cpp

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,58 @@ converts_keepdims_correctly(mean, Mean);
134134

135135
#undef converts_keepdims_correctly
136136

137+
// Verifies that aten::sum with a negative reduce dimension (-1) is converted
// correctly when keep_dims is disabled (prim::Constant[value=0]).
TEST(Converters, ATenSumDimNegOneIndexConvertsCorrectly) {
  const auto ir = R"IR(
graph(%0 : Tensor):
%1 : int = prim::Constant[value=-1]()
%2 : int[] = prim::ListConstruct(%1)
%3 : bool = prim::Constant[value=0]()
%4 : None = prim::Constant()
%5 : Tensor = aten::sum(%0, %2, %3, %4)
return (%5))IR";
  // Random 4x4x4 input on CUDA; values drawn from [-5, 5).
  auto input = at::randint(-5, 5, {4, 4, 4}, at::kCUDA);
  test_body(ir, input);
}
149+
150+
// Verifies that aten::sum with a negative reduce dimension (-1) is converted
// correctly when keep_dims is enabled (prim::Constant[value=1]).
TEST(Converters, ATenSumDimNegOneIndexKeepDimsConvertsCorrectly) {
  const auto ir = R"IR(
graph(%0 : Tensor):
%1 : int = prim::Constant[value=-1]()
%2 : int[] = prim::ListConstruct(%1)
%3 : bool = prim::Constant[value=1]()
%4 : None = prim::Constant()
%5 : Tensor = aten::sum(%0, %2, %3, %4)
return (%5))IR";
  // Random 4x4x4 input on CUDA; values drawn from [-5, 5).
  auto input = at::randint(-5, 5, {4, 4, 4}, at::kCUDA);
  test_body(ir, input);
}
162+
163+
// Verifies that aten::sum with a negative, non-last reduce dimension (-2)
// is converted correctly when keep_dims is disabled.
TEST(Converters, ATenSumDimNegIndexConvertsCorrectly) {
  const auto ir = R"IR(
graph(%0 : Tensor):
%1 : int = prim::Constant[value=-2]()
%2 : int[] = prim::ListConstruct(%1)
%3 : bool = prim::Constant[value=0]()
%4 : None = prim::Constant()
%5 : Tensor = aten::sum(%0, %2, %3, %4)
return (%5))IR";
  // Random 4x4x4 input on CUDA; values drawn from [-5, 5).
  auto input = at::randint(-5, 5, {4, 4, 4}, at::kCUDA);
  test_body(ir, input);
}
175+
176+
// Verifies that aten::sum with a negative, non-last reduce dimension (-2)
// is converted correctly when keep_dims is enabled.
TEST(Converters, ATenSumDimNegIndexKeepDimsConvertsCorrectly) {
  const auto ir = R"IR(
graph(%0 : Tensor):
%1 : int = prim::Constant[value=-2]()
%2 : int[] = prim::ListConstruct(%1)
%3 : bool = prim::Constant[value=1]()
%4 : None = prim::Constant()
%5 : Tensor = aten::sum(%0, %2, %3, %4)
return (%5))IR";
  // Random 4x4x4 input on CUDA; values drawn from [-5, 5).
  auto input = at::randint(-5, 5, {4, 4, 4}, at::kCUDA);
  test_body(ir, input);
}
188+
137189
TEST(Converters, ATenProdDimConvertsCorrectly) {
138190
const auto graph = R"IR(
139191
graph(%0 : Tensor):

0 commit comments

Comments
 (0)