@@ -76,3 +76,55 @@ TEST(Converters, ATenSoftmaxNDConvertsCorrectlyAbove3DIndex) {
76
76
77
77
ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt, 2e-6 ));
78
78
}
79
+
80
+ TEST (Converters, ATenSoftmaxNDConvertsCorrectlyNegtiveOneIndex) {
81
+ const auto graph = R"IR(
82
+ graph(%0 : Tensor):
83
+ %1 : None = prim::Constant()
84
+ %2 : int = prim::Constant[value=-1]()
85
+ %3 : Tensor = aten::softmax(%0, %2, %1)
86
+ return (%3))IR" ;
87
+
88
+ auto g = std::make_shared<torch::jit::Graph>();
89
+ torch::jit::parseIR (graph, &*g);
90
+
91
+ auto in = at::randint (0 , 5 , {1 , 2 , 2 , 2 , 2 }, {at::kCUDA });
92
+
93
+ auto jit_in = at::clone (in);
94
+ auto params = trtorch::core::conversion::get_named_params (g->inputs (), {});
95
+ auto jit_results = trtorch::tests::util::RunGraph (g, params, {jit_in});
96
+
97
+ auto trt_in = at::clone (in);
98
+ params = trtorch::core::conversion::get_named_params (g->inputs (), {});
99
+ auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {trt_in});
100
+
101
+ auto trt = trt_results[0 ].reshape_as (jit_results[0 ]);
102
+
103
+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt, 2e-6 ));
104
+ }
105
+
106
+ TEST (Converters, ATenSoftmaxNDConvertsCorrectlyNegtiveIndex) {
107
+ const auto graph = R"IR(
108
+ graph(%0 : Tensor):
109
+ %1 : None = prim::Constant()
110
+ %2 : int = prim::Constant[value=-2]()
111
+ %3 : Tensor = aten::softmax(%0, %2, %1)
112
+ return (%3))IR" ;
113
+
114
+ auto g = std::make_shared<torch::jit::Graph>();
115
+ torch::jit::parseIR (graph, &*g);
116
+
117
+ auto in = at::randint (0 , 5 , {1 , 2 , 2 , 2 , 2 }, {at::kCUDA });
118
+
119
+ auto jit_in = at::clone (in);
120
+ auto params = trtorch::core::conversion::get_named_params (g->inputs (), {});
121
+ auto jit_results = trtorch::tests::util::RunGraph (g, params, {jit_in});
122
+
123
+ auto trt_in = at::clone (in);
124
+ params = trtorch::core::conversion::get_named_params (g->inputs (), {});
125
+ auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {trt_in});
126
+
127
+ auto trt = trt_results[0 ].reshape_as (jit_results[0 ]);
128
+
129
+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt, 2e-6 ));
130
+ }
0 commit comments