
Commit 17f2e8f

Merge pull request #316 from NVIDIA/index_-1
support index=-1, and fix bug when nDims=2 or 3 in softmax
2 parents 3d1fbfd + abc29a2 commit 17f2e8f
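In brief, the diff below makes two changes to the softmax converter: a negative dim (e.g. dim=-1) is now normalized to shape.size() + dim before the TensorRT layer is configured, and the batched-axes check is relaxed from shape.size() > 3 to shape.size() > 1, so 2-D and 3-D inputs no longer fall into the no-batch-dimension fallback (the "bug when nDims=2 or 3" named in the commit message).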

2 files changed: +58 -1 lines changed
core/conversion/converters/impl/softmax.cpp

Lines changed: 6 additions & 1 deletion
@@ -26,12 +26,17 @@ static auto softmax_registrations TRTORCH_UNUSED = RegisterNodeConversionPattern
   }
 
   int64_t dim = args[1].IValue()->toInt();
+  LOG_DEBUG("Softmax original dim " << dim);
+  if (dim < 0) {
+    dim = shape.size() + dim;
+  }
+  LOG_DEBUG("Softmax converted dim " << dim);
   auto softmax = ctx->net->addSoftMax(*in);
 
   TRTORCH_CHECK(softmax, "Unable to create softmax layer from node: " << *n);
   LOG_DEBUG("Disregarding dtype argument");
 
-  if (shape.size() > 3) {
+  if (shape.size() > 1) {
     softmax->setAxes(1 << (dim));
   } else {
     // When there is no batch dimension
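For context on the hunk above: the converter now normalizes a negative PyTorch dim to a non-negative axis (dim = shape.size() + dim) before configuring the TensorRT layer, and the axes argument is a bitmask with one bit set per axis, hence 1 << dim. A minimal standalone sketch of that logic (the helper names normalize_dim and axes_mask are hypothetical, for illustration only; they are not TRTorch or TensorRT API):

#include <cassert>
#include <cstdint>
#include <iostream>

// Map a PyTorch-style dim (possibly negative) to a non-negative axis,
// mirroring the `dim = shape.size() + dim` rewrite in the converter.
int64_t normalize_dim(int64_t dim, int64_t rank) {
  if (dim < 0) {
    dim = rank + dim;  // e.g. dim = -1 with rank = 4 becomes 3
  }
  assert(dim >= 0 && dim < rank);
  return dim;
}

// Encode the chosen axis as a one-hot bitmask, as `1 << dim` does
// before the value is handed to the softmax layer's axes setting.
uint32_t axes_mask(int64_t dim) {
  return 1u << static_cast<uint32_t>(dim);
}

int main() {
  int64_t axis = normalize_dim(-1, 4);
  std::cout << "axis=" << axis << " mask=" << axes_mask(axis) << "\n";  // axis=3 mask=8
  return 0;
}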

tests/core/conversion/converters/test_softmax.cpp

Lines changed: 52 additions & 0 deletions
@@ -76,3 +76,55 @@ TEST(Converters, ATenSoftmaxNDConvertsCorrectlyAbove3DIndex) {
 
   ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
 }
+
+TEST(Converters, ATenSoftmaxNDConvertsCorrectlyNegtiveOneIndex) {
+  const auto graph = R"IR(
+    graph(%0 : Tensor):
+      %1 : None = prim::Constant()
+      %2 : int = prim::Constant[value=-1]()
+      %3 : Tensor = aten::softmax(%0, %2, %1)
+      return (%3))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, &*g);
+
+  auto in = at::randint(0, 5, {1, 2, 2, 2, 2}, {at::kCUDA});
+
+  auto jit_in = at::clone(in);
+  auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
+  auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in});
+
+  auto trt_in = at::clone(in);
+  params = trtorch::core::conversion::get_named_params(g->inputs(), {});
+  auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in});
+
+  auto trt = trt_results[0].reshape_as(jit_results[0]);
+
+  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
+}
+
+TEST(Converters, ATenSoftmaxNDConvertsCorrectlyNegtiveIndex) {
+  const auto graph = R"IR(
+    graph(%0 : Tensor):
+      %1 : None = prim::Constant()
+      %2 : int = prim::Constant[value=-2]()
+      %3 : Tensor = aten::softmax(%0, %2, %1)
+      return (%3))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, &*g);
+
+  auto in = at::randint(0, 5, {1, 2, 2, 2, 2}, {at::kCUDA});
+
+  auto jit_in = at::clone(in);
+  auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
+  auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in});
+
+  auto trt_in = at::clone(in);
+  params = trtorch::core::conversion::get_named_params(g->inputs(), {});
+  auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in});
+
+  auto trt = trt_results[0].reshape_as(jit_results[0]);
+
+  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
+}
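Both new tests follow the file's existing harness pattern: parse the IR, run the graph through TorchScript (RunGraph) and through a TensorRT engine (RunGraphEngine), then compare the outputs. As an independent sanity check of the property under test, here is a LibTorch-only sketch (assuming a local LibTorch install; this is not part of the test suite) showing that a negative dim is equivalent to its normalized positive counterpart:

#include <torch/torch.h>
#include <iostream>

int main() {
  // Same 5-D shape the tests use; CPU is fine for this check.
  auto in = torch::randint(0, 5, {1, 2, 2, 2, 2}).to(torch::kFloat);

  // dim = -1 should match dim = rank - 1 = 4, the identity the
  // converter relies on when rewriting negative dims.
  auto neg = torch::softmax(in, /*dim=*/-1);
  auto pos = torch::softmax(in, /*dim=*/4);

  std::cout << std::boolalpha << torch::allclose(neg, pos) << "\n";  // true
  return 0;
}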
