@@ -72,6 +72,59 @@ class OpLogSoftmaxOutTest : public OperatorTest {
7272 EXPECT_TENSOR_CLOSE (out, expected);
7373 }
7474 }
75+
76+ template <class CTYPE , executorch::aten::ScalarType DTYPE>
77+ void test_dtype_noncontiguous_dim () {
78+ TensorFactory<DTYPE> tf;
79+
80+ // Dim 0 must be longer than the vector width of the machine (for
81+ // float, this is 4 for ARM64 and 8 for AVX2) to exhibit problems.
82+ // clang-format off
83+ Tensor x = tf.make (
84+ {9 , 3 },
85+ {
86+ 0 , 9 , 18 ,
87+ 1 , 10 , 19 ,
88+ 2 , 11 , 20 ,
89+ 3 , 12 , 21 ,
90+ 4 , 13 , 22 ,
91+ 5 , 14 , 23 ,
92+ 6 , 15 , 24 ,
93+ 7 , 16 , 25 ,
94+ 8 , 17 , 26 ,
95+ });
96+ // clang-format on
97+
98+ Tensor out = tf.zeros ({9 , 3 });
99+
100+ op_log_softmax_out (x, /* dim=*/ 0 , /* half_to_float*/ false , out);
101+
102+ // clang-format off
103+ Tensor expected = tf.make (
104+ {9 , 3 },
105+ {
106+ -8.45855 , -8.45855 , -8.45855 ,
107+ -7.45855 , -7.45855 , -7.45855 ,
108+ -6.45855 , -6.45855 , -6.45855 ,
109+ -5.45855 , -5.45855 , -5.45855 ,
110+ -4.45855 , -4.45855 , -4.45855 ,
111+ -3.45855 , -3.45855 , -3.45855 ,
112+ -2.45855 , -2.45855 , -2.45855 ,
113+ -1.45855 , -1.45855 , -1.45855 ,
114+ -0.458552 , -0.458552 , -0.458552
115+ });
116+ // clang-format on
117+
118+ if constexpr (DTYPE == ScalarType::BFloat16) {
119+ EXPECT_TENSOR_CLOSE_WITH_TOL (
120+ out,
121+ expected,
122+ 1e-2 ,
123+ executorch::runtime::testing::internal::kDefaultAtol );
124+ } else {
125+ EXPECT_TENSOR_CLOSE (out, expected);
126+ }
127+ }
75128};
76129
77130TEST_F (OpLogSoftmaxOutTest, Smoke) {
@@ -101,6 +154,10 @@ TEST_F(OpLogSoftmaxOutTest, AllDtypesSupported) {
101154#undef TEST_ENTRY
102155}
103156
157+ TEST_F (OpLogSoftmaxOutTest, NonContiguous) {
158+ test_dtype_noncontiguous_dim<float , ScalarType::Float>();
159+ }
160+
104161TEST_F (OpLogSoftmaxOutTest, MismatchedDimensionsDies) {
105162 if (SupportedFeatures::get ()->is_aten ) {
106163 GTEST_SKIP () << " ATen currently supports mismatched dimensions" ;
0 commit comments