Skip to content

Commit 96850b6

Browse files
inocsinnarendasan
authored andcommitted
support sum and softmax with reduce_dim = -1
Signed-off-by: inocsin <[email protected]>
1 parent 7e467a6 commit 96850b6

File tree

2 files changed

+15
-3
lines changed

2 files changed

+15
-3
lines changed

core/conversion/converters/impl/reduce.cpp

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,11 +80,20 @@ auto reduce_registrations TRTORCH_UNUSED =
8080
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
8181
auto in_tensor = args[0].ITensorOrFreeze(ctx);
8282
auto dims = args[1].unwrapToIntList();
83-
LOG_DEBUG("Dim to reduce:" << util::toDims(dims)); // Some abuse of toDim but just for debug info
83+
c10::List<int64_t> dims_copy;
84+
auto in_dims = util::toVec(in_tensor->getDimensions());
85+
LOG_DEBUG("InDims " << in_dims); // Some abuse of toDim but just for debug info
86+
LOG_DEBUG("Dim to reduce(original):" << util::toDims(dims)); // Some abuse of toDim but just for debug info
87+
for (int i = 0; i < dims.size(); i++) {
88+
auto dim_val = dims[i] == -1 ? (in_dims.size() - 1) : dims[i];
89+
dims_copy.push_back(dim_val);
90+
}
91+
92+
LOG_DEBUG("Dim to reduce(converted):" << util::toDims(dims_copy)); // Some abuse of toDim but just for debug info
8493

8594
uint32_t axis_mask = 0;
86-
for (size_t d = 0; d < dims.size(); d++) {
87-
axis_mask |= 1 << dims[d];
95+
for (size_t d = 0; d < dims_copy.size(); d++) {
96+
axis_mask |= 1 << dims_copy[d];
8897
}
8998
LOG_DEBUG("Axis Mask" << std::bitset<32>(axis_mask));
9099

core/conversion/converters/impl/softmax.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@ static auto softmax_registrations TRTORCH_UNUSED = RegisterNodeConversionPattern
2626
}
2727

2828
int64_t dim = args[1].IValue()->toInt();
29+
if (dim == -1) {
30+
dim = shape.size() - 1;
31+
}
2932
auto softmax = ctx->net->addSoftMax(*in);
3033

3134
TRTORCH_CHECK(softmax, "Unable to create softmax layer from node: " << *n);

0 commit comments

Comments
 (0)