From 53e3e2b7546974ad93cd88f3d812b56d37d22eb0 Mon Sep 17 00:00:00 2001
From: Mergen Nachin
Date: Fri, 5 Sep 2025 07:12:33 -0700
Subject: [PATCH] Fix crash in optimized log softmax (#13953)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Summary:

Fixes https://github.com/pytorch/executorch/issues/13551

🐛 Problem

The optimized log_softmax kernel crashed with a fatal assertion when
encountering unsupported double-precision dtypes:

  F 00:00:00.005478 executorch:op_log_softmax.cpp:156] assert failed (false): Unhandled out dtype 7

✅ Solution

Replaced the fatal ET_CHECK_MSG with graceful ET_KERNEL_CHECK error handling:

Before: program termination on unsupported dtypes
After: returns error codes and allows the program to continue

📝 Changes

kernels/optimized/cpu/op_log_softmax.cpp:
- Converted log_softmax_wrapper to return a bool success status
- Replaced fatal assertions with graceful error returns
- Added proper error state handling via ET_KERNEL_CHECK

kernels/test/op_log_softmax_test.cpp:
- Added a DoubleCase test for comprehensive double-precision validation
- Implemented conditional testing: runs full validation on portable kernels,
  verifies graceful failure and skips on optimized kernels
- Uses expect_failure() for CI-friendly error state handling

🎯 Results

| Kernel    | Double Precision Support | Behavior                    | CI Status  |
|-----------|--------------------------|-----------------------------|------------|
| Portable  | ✅ Full support          | Test passes with validation | 🟢 PASSED  |
| Optimized | ❌ Not supported         | Graceful error + skip       | 🟢 SKIPPED |

🛡️ Benefits

- No more crashes: applications can handle unsupported operations gracefully
- Better testability: error conditions can now be tested and validated
- CI compatibility: green builds for both supported and unsupported operations
- Regression protection: prevents future reintroduction of fatal errors

Testing: verified on both the portable and optimized kernel test suites ✅

Reviewed By: manuelcandales

Differential Revision: D81703958

Pulled By: mergennachin
---
 kernels/optimized/cpu/op_log_softmax.cpp | 20 +++----
 kernels/test/op_log_softmax_test.cpp     | 66 ++++++++++++++++++++++++
 2 files changed, 74 insertions(+), 12 deletions(-)

diff --git a/kernels/optimized/cpu/op_log_softmax.cpp b/kernels/optimized/cpu/op_log_softmax.cpp
index c4eac7594f3..f56b0a37de2 100644
--- a/kernels/optimized/cpu/op_log_softmax.cpp
+++ b/kernels/optimized/cpu/op_log_softmax.cpp
@@ -103,18 +103,15 @@ void log_softmax_kernel(const Tensor& input, int64_t dim, Tensor& out) {
 template <
     typename OUT_T,
     std::enable_if_t<std::is_same<OUT_T, float>::value, bool> = true>
-void log_softmax_wrapper(const Tensor& X, int64_t dim, Tensor& out) {
+bool log_softmax_wrapper(const Tensor& X, int64_t dim, Tensor& out) {
   auto input_scalar_type = X.scalar_type();
   switch (input_scalar_type) {
     // TODO: support Double as well
    case ScalarType::Float:
       log_softmax_kernel<OUT_T>(X, dim, out);
-      break;
+      return true;
    default:
-      ET_CHECK_MSG(
-          false,
-          "Unhandled input dtype %" PRId8,
-          static_cast<int8_t>(input_scalar_type));
+      return false; // Unsupported input dtype
   }
 }
 } // namespace
@@ -146,14 +143,13 @@ Tensor& opt_log_softmax_out(
   auto out_scalar_type = out.scalar_type();
   switch (out_scalar_type) {
     // TODO: support Double as well
-    case ScalarType::Float:
-      log_softmax_wrapper<float>(self, dim, out);
+    case ScalarType::Float: {
+      bool success = log_softmax_wrapper<float>(self, dim, out);
+      ET_KERNEL_CHECK(context, success, InvalidArgument, out);
       break;
+    }
    default:
-      ET_CHECK_MSG(
-          false,
-          "Unhandled out dtype %" PRId8,
-          static_cast<int8_t>(out_scalar_type));
+      ET_KERNEL_CHECK(context, false, InvalidArgument, out);
   }
   return out;
 }
diff --git a/kernels/test/op_log_softmax_test.cpp b/kernels/test/op_log_softmax_test.cpp
index 1b01ff8a78d..3bcbee96a1c 100644
--- a/kernels/test/op_log_softmax_test.cpp
+++ b/kernels/test/op_log_softmax_test.cpp
@@ -447,3 +447,69 @@ TEST_F(OpLogSoftmaxOutTest, DynamicShapeUnbound) {
   Tensor ret = op_log_softmax_out(x, 1, false, out);
   EXPECT_TENSOR_CLOSE(out, expected_result);
 }
+
+TEST_F(OpLogSoftmaxOutTest, DoubleCase) {
+  TensorFactory<ScalarType::Double> tf;
+
+  // Test case with specific inputs:
+  // Input tensor: torch.float64 (8, 5, 7)
+  // Dim: 2
+  // half_to_float: False
+  Tensor input = tf.zeros({8, 5, 7});
+  auto in_data = input.mutable_data_ptr<double>();
+
+  // Fill with some test data (sequential values scaled)
+  for (int i = 0; i < 8 * 5 * 7; i++) {
+    in_data[i] = static_cast<double>(i) * 0.01;
+  }
+
+  // Output tensor with same shape
+  Tensor out = tf.zeros({8, 5, 7});
+
+  // Apply log_softmax along dimension 2 (the last dimension with size 7)
+  op_log_softmax_out(input, /*dim=*/2, /*half_to_float=*/false, out);
+
+  if (!SupportedFeatures::get()->op_log_softmax_dtype_double) {
+    // For optimized kernels, we expect the call above to fail gracefully
+    expect_failure();
+    GTEST_SKIP() << "This kernel does not support dtype double";
+  }
+
+  // Verify output dimensions
+  EXPECT_EQ(out.size(0), 8);
+  EXPECT_EQ(out.size(1), 5);
+  EXPECT_EQ(out.size(2), 7);
+
+  // Verify that output has reasonable values
+  auto out_data = out.const_data_ptr<double>();
+
+  // Check for NaN or Inf values
+  for (int i = 0; i < 8 * 5 * 7; i++) {
+    EXPECT_FALSE(std::isnan(out_data[i]))
+        << "Output should not contain NaN at index " << i;
+    EXPECT_FALSE(std::isinf(out_data[i]))
+        << "Output should not contain Inf at index " << i;
+  }
+
+  // For log_softmax, all values should be <= 0 (since softmax values are <= 1,
+  // log is <= 0)
+  for (int i = 0; i < 8 * 5 * 7; i++) {
+    EXPECT_LE(out_data[i], 0.0)
+        << "Log softmax values should be <= 0 at index " << i;
+  }
+
+  // Verify that each slice along dimension 2 sums to approximately 1 when exp'd
+  // This tests the core property of softmax: sum(softmax(x)) = 1
+  for (int batch = 0; batch < 8; batch++) {
+    for (int channel = 0; channel < 5; channel++) {
+      double sum_exp = 0.0;
+      for (int dim2 = 0; dim2 < 7; dim2++) {
+        int idx = batch * 5 * 7 + channel * 7 + dim2;
+        sum_exp += std::exp(out_data[idx]);
+      }
+      EXPECT_NEAR(sum_exp, 1.0, 1e-6)
+          << "Sum of exp(log_softmax) should be 1.0 for batch=" << batch
+          << ", channel=" << channel;
+    }
+  }
+}
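
Note for readers unfamiliar with the two macros being swapped: the sketch below is a self-contained illustration of the error-handling contract the patch relies on, not ExecuTorch source. MockContext and MOCK_KERNEL_CHECK are hypothetical stand-ins for the kernel runtime context and ET_KERNEL_CHECK. The point it demonstrates is that the wrapper reports dtype support through a bool return value, and the check records a failure on the context and returns to the caller instead of aborting, which is what lets the test above call expect_failure() and then GTEST_SKIP().

// Standalone sketch (hypothetical names) of "record an error instead of aborting".
#include <cstdio>

enum class Error { Ok, InvalidArgument };

// Stand-in for the kernel runtime context: remembers the failure, if any.
struct MockContext {
  Error failure_state = Error::Ok;
  void fail(Error e) { failure_state = e; }
};

// Simplified analogue of ET_KERNEL_CHECK(ctx, cond, error, retval): on a failed
// condition it flags the context and returns retval to the caller, rather than
// terminating the program the way a fatal ET_CHECK_MSG would.
#define MOCK_KERNEL_CHECK(ctx, cond, err, retval) \
  do {                                            \
    if (!(cond)) {                                \
      (ctx).fail(Error::err);                     \
      return retval;                              \
    }                                             \
  } while (false)

// Mirrors the new wrapper contract: report dtype support via a bool return.
bool run_kernel(bool dtype_supported) { return dtype_supported; }

int& op_out(MockContext& ctx, bool dtype_supported, int& out) {
  bool success = run_kernel(dtype_supported);
  MOCK_KERNEL_CHECK(ctx, success, InvalidArgument, out);
  out = 42;  // Only reached on the supported path.
  return out;
}

int main() {
  MockContext ctx;
  int out = 0;
  op_out(ctx, /*dtype_supported=*/false, out);
  // The program keeps running; the caller inspects the context afterwards,
  // analogous to the test calling expect_failure() before skipping.
  std::printf("failed gracefully: %d\n", ctx.failure_state == Error::InvalidArgument);
  return 0;
}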