
Commit 6ab0b50
Author: Github Executorch

Fix log_softmax along non-contiguous dim

#8382 certainly didn't fix this problem (and added it on x86), but I don't think it was correct on ARM prior to that either. Added a regression test.

Differential Revision: [D69928884](https://our.internmc.facebook.com/intern/diff/D69928884/)
ghstack-source-id: 267433500
Pull Request resolved: #8595

1 parent 2fff01a

2 files changed: +69, -8 lines

kernels/optimized/cpu/op_log_softmax.cpp

Lines changed: 11 additions & 8 deletions
@@ -75,17 +75,20 @@ void log_softmax_kernel(const Tensor& input, int64_t dim, Tensor& out) {
     static_assert(
         std::is_same_v<OUT_T, float>,
         "Below loop actually only supports float.");
-    const VecIn max_input_vec(max_input);
-    for (; d + VecOut::size() < dim_size; d += VecOut::size()) {
-      auto index = d * dim_stride;
-      auto in = VecIn::loadu(&input_data[index]);
-      auto out_ = (in - max_input_vec).exp();
-      out_.store(&output_data[index]);
+    // It is not correct to vectorize if dim is not contiguous!
+    if (dim_stride == 1) {
+      const VecIn max_input_vec(max_input);
+      for (; d + VecOut::size() < dim_size; d += VecOut::size()) {
+        auto index = d * dim_stride;
+        auto in = VecIn::loadu(&input_data[index]);
+        auto out_ = (in - max_input_vec).exp();
+        out_.store(&output_data[index]);
 #if defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE)
-      temp_sum += vaddvq_f32(out_);
+        temp_sum += vaddvq_f32(out_);
 #else
-      temp_sum += at::vec::vec_reduce_all<float>(std::plus<VecOut>(), out_);
+        temp_sum += at::vec::vec_reduce_all<float>(std::plus<VecOut>(), out_);
 #endif
+      }
     }
     for (; d < dim_size; ++d) {
       output_data[d * dim_stride] =
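For readers unfamiliar with the layout here, the following is a minimal standalone sketch (illustrative assumptions only, not ExecuTorch code) of why the removed loop was incorrect when dim_stride != 1: VecIn::loadu reads consecutive floats starting at &input_data[d * dim_stride], whereas the elements that belong to the softmax dim are spaced dim_stride apart, so a contiguous vector load mixes in values from other slices.

// Illustrative sketch, not ExecuTorch code: shows which flat indices a
// contiguous vector load touches versus which indices actually belong to
// the softmax dim when dim_stride != 1. Sizes and names are assumptions.
#include <cstdio>

int main() {
  const int dim_size = 9;    // length of the softmax dim (dim 0 of a {9, 3} tensor)
  const int dim_stride = 3;  // stride of dim 0 in a row-major {9, 3} float tensor
  const int vec_width = 4;   // e.g. 4 float lanes on ARM64 NEON

  // The d-th element of one softmax slice lives at flat index d * dim_stride.
  for (int d = 0; d < dim_size; ++d) {
    std::printf("dim element %d -> flat index %d\n", d, d * dim_stride);
  }

  // A contiguous load at &data[d * dim_stride] instead reads flat indices
  // d * dim_stride .. d * dim_stride + vec_width - 1, which straddles other
  // slices unless dim_stride == 1 -- exactly the condition the fix checks.
  const int d = 0;
  std::printf("loadu at d=%d reads flat indices %d..%d\n",
              d, d * dim_stride, d * dim_stride + vec_width - 1);
  return 0;
}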

kernels/test/op_log_softmax_test.cpp

Lines changed: 58 additions & 0 deletions
@@ -72,6 +72,60 @@ class OpLogSoftmaxOutTest : public OperatorTest {
       EXPECT_TENSOR_CLOSE(out, expected);
     }
   }
+
+  template <class CTYPE, executorch::aten::ScalarType DTYPE>
+  void test_dtype_noncontiguous_dim() {
+    TensorFactory<DTYPE> tf;
+
+    // clang-format off Dim 0 must be longer than the vector width of
+    // the machine (for float, this is 4 for ARM64 and 8 for AVX2) to
+    // exhibit problems.
+    Tensor x = tf.make(
+        {9, 3},
+        {
+            0, 9, 18,
+            1, 10, 19,
+            2, 11, 20,
+            3, 12, 21,
+            4, 13, 22,
+            5, 14, 23,
+            6, 15, 24,
+            7, 16, 25,
+            8, 17, 26,
+        });
+    // clang-format on
+
+    Tensor out = tf.zeros({9, 3});
+
+    op_log_softmax_out(x, /*dim=*/0, /*half_to_float*/ false, out);
+
+    // clang-format off
+    Tensor expected = tf.make(
+        {9, 3},
+        {
+            -8.45855, -8.45855, -8.45855,
+            -7.45855, -7.45855, -7.45855,
+            -6.45855, -6.45855, -6.45855,
+            -5.45855, -5.45855, -5.45855,
+            -4.45855, -4.45855, -4.45855,
+            -3.45855, -3.45855, -3.45855,
+            -2.45855, -2.45855, -2.45855,
+            -1.45855, -1.45855, -1.45855,
+            -0.458552, -0.458552, -0.458552
+        });
+    // clang-format on
+
+    if constexpr (DTYPE == ScalarType::BFloat16) {
+      EXPECT_TENSOR_CLOSE_WITH_TOL(
+          out,
+          expected,
+          1e-2,
+          executorch::runtime::testing::internal::kDefaultAtol);
+    } else {
+      EXPECT_TENSOR_CLOSE(out, expected);
+    }
+  }
+
 };

 TEST_F(OpLogSoftmaxOutTest, Smoke) {
@@ -101,6 +155,10 @@ TEST_F(OpLogSoftmaxOutTest, AllDtypesSupported) {
 #undef TEST_ENTRY
 }

+TEST_F(OpLogSoftmaxOutTest, NonContiguous) {
+  test_dtype_noncontiguous_dim<float, ScalarType::Float>();
+}
+
 TEST_F(OpLogSoftmaxOutTest, MismatchedDimensionsDies) {
   if (SupportedFeatures::get()->is_aten) {
     GTEST_SKIP() << "ATen currently supports mismatched dimensions";
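As a cross-check on the expected values in the regression test: each column of the input is 0..8 plus a constant offset, and adding a constant shifts both x_i and the column's log-sum-exp equally, so all three columns of `expected` are identical and equal to i - log(sum over k of exp(k)) for k = 0..8, i.e. i - 8.45855. Below is a minimal standalone sketch (plain C++, illustrative only, not part of the test suite) that reproduces that column.

// Illustrative sketch, not ExecuTorch test code: computes log_softmax over
// the column {0, 1, ..., 8} with the usual max-subtraction trick and prints
// values matching the `expected` tensor above (-8.45855 .. -0.458552).
#include <cmath>
#include <cstdio>

int main() {
  const int n = 9;
  double x[n];
  for (int i = 0; i < n; ++i) {
    x[i] = i;  // each column of the {9, 3} input is 0..8 plus a constant
  }

  // log_softmax(x_i) = x_i - (max + log(sum_j exp(x_j - max)))
  const double max_x = x[n - 1];
  double sum = 0.0;
  for (int i = 0; i < n; ++i) {
    sum += std::exp(x[i] - max_x);
  }
  const double log_sum_exp = max_x + std::log(sum);  // ~8.45855

  for (int i = 0; i < n; ++i) {
    std::printf("log_softmax(%d) = %f\n", i, x[i] - log_sum_exp);
  }
  return 0;
}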
