20 changes: 8 additions & 12 deletions kernels/optimized/cpu/op_log_softmax.cpp
@@ -103,18 +103,15 @@ void log_softmax_kernel(const Tensor& input, int64_t dim, Tensor& out) {
 template <
     typename OUT_T,
     std::enable_if_t<std::is_floating_point<OUT_T>::value, bool> = true>
-void log_softmax_wrapper(const Tensor& X, int64_t dim, Tensor& out) {
+bool log_softmax_wrapper(const Tensor& X, int64_t dim, Tensor& out) {
   auto input_scalar_type = X.scalar_type();
   switch (input_scalar_type) {
     // TODO: support Double as well
     case ScalarType::Float:
       log_softmax_kernel<float, OUT_T>(X, dim, out);
-      break;
+      return true;
     default:
-      ET_CHECK_MSG(
-          false,
-          "Unhandled input dtype %" PRId8,
-          static_cast<int8_t>(input_scalar_type));
+      return false; // Unsupported input dtype
   }
 }
 } // namespace
@@ -146,14 +143,13 @@ Tensor& opt_log_softmax_out(
   auto out_scalar_type = out.scalar_type();
   switch (out_scalar_type) {
     // TODO: support Double as well
-    case ScalarType::Float:
-      log_softmax_wrapper<float>(self, dim, out);
+    case ScalarType::Float: {
+      bool success = log_softmax_wrapper<float>(self, dim, out);
+      ET_KERNEL_CHECK(context, success, InvalidArgument, out);
       break;
+    }
     default:
-      ET_CHECK_MSG(
-          false,
-          "Unhandled out dtype %" PRId8,
-          static_cast<int8_t>(out_scalar_type));
+      ET_KERNEL_CHECK(context, false, InvalidArgument, out);
   }
   return out;
 }
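Note on the change above: log_softmax_wrapper now reports an unsupported input dtype by returning false instead of tripping ET_CHECK_MSG(false, ...), which aborts, and opt_log_softmax_out converts that result into ET_KERNEL_CHECK(context, success, InvalidArgument, out), so the failure is recorded on the kernel runtime context and the op returns gracefully. A minimal sketch of the same pattern, assuming the same headers, namespaces, and ET_KERNEL_CHECK macro as op_log_softmax.cpp; try_kernel and example_op_out are illustrative names, not code from this PR:

namespace {

// Report "unsupported dtype" to the caller instead of aborting in the kernel.
bool try_kernel(const Tensor& in, Tensor& out) {
  switch (in.scalar_type()) {
    case ScalarType::Float:
      // ... the actual kernel work would go here ...
      return true;
    default:
      return false; // unsupported input dtype
  }
}

} // namespace

Tensor& example_op_out(KernelRuntimeContext& context, const Tensor& in, Tensor& out) {
  bool success = try_kernel(in, out);
  // When success is false, ET_KERNEL_CHECK flags InvalidArgument on the
  // runtime context and returns `out`, rather than terminating the process
  // the way the removed ET_CHECK_MSG(false, ...) path did.
  ET_KERNEL_CHECK(context, success, InvalidArgument, out);
  return out;
}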
66 changes: 66 additions & 0 deletions kernels/test/op_log_softmax_test.cpp
@@ -447,3 +447,69 @@ TEST_F(OpLogSoftmaxOutTest, DynamicShapeUnbound) {
Tensor ret = op_log_softmax_out(x, 1, false, out);
EXPECT_TENSOR_CLOSE(out, expected_result);
}

TEST_F(OpLogSoftmaxOutTest, DoubleCase) {
TensorFactory<ScalarType::Double> tf;

// Test case with specific inputs:
// Input tensor: torch.float64 (8, 5, 7)
// Dim: 2
// half_to_float: False
Tensor input = tf.zeros({8, 5, 7});
auto in_data = input.mutable_data_ptr<double>();

// Fill with some test data (sequential values scaled)
for (int i = 0; i < 8 * 5 * 7; i++) {
in_data[i] = static_cast<double>(i) * 0.01;
}

// Output tensor with same shape
Tensor out = tf.zeros({8, 5, 7});

// Apply log_softmax along dimension 2 (the last dimension with size 7)
op_log_softmax_out(input, /*dim=*/2, /*half_to_float=*/false, out);

if (!SupportedFeatures::get()->op_log_softmax_dtype_double) {
// For optimized kernels, we expect the call above to fail gracefully
expect_failure();
GTEST_SKIP() << "This kernel does not support dtype double";
}

// Verify output dimensions
EXPECT_EQ(out.size(0), 8);
EXPECT_EQ(out.size(1), 5);
EXPECT_EQ(out.size(2), 7);

// Verify that output has reasonable values
auto out_data = out.const_data_ptr<double>();

// Check for NaN or Inf values
for (int i = 0; i < 8 * 5 * 7; i++) {
EXPECT_FALSE(std::isnan(out_data[i]))
<< "Output should not contain NaN at index " << i;
EXPECT_FALSE(std::isinf(out_data[i]))
<< "Output should not contain Inf at index " << i;
}

// For log_softmax, all values should be <= 0 (since softmax values are <= 1,
// log is <= 0)
for (int i = 0; i < 8 * 5 * 7; i++) {
EXPECT_LE(out_data[i], 0.0)
<< "Log softmax values should be <= 0 at index " << i;
}

// Verify that each slice along dimension 2 sums to approximately 1 when exp'd
// This tests the core property of softmax: sum(softmax(x)) = 1
for (int batch = 0; batch < 8; batch++) {
for (int channel = 0; channel < 5; channel++) {
double sum_exp = 0.0;
for (int dim2 = 0; dim2 < 7; dim2++) {
int idx = batch * 5 * 7 + channel * 7 + dim2;
sum_exp += std::exp(out_data[idx]);
}
EXPECT_NEAR(sum_exp, 1.0, 1e-6)
<< "Sum of exp(log_softmax) should be 1.0 for batch=" << batch
<< ", channel=" << channel;
}
}
}
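The assertions in the new test follow from the definition log_softmax(x)_i = x_i - log(sum_j exp(x_j)): every output is the log of a probability, so it is <= 0, and exponentiating one slice along the reduced dimension sums to 1. A small standalone reference in plain C++ (independent of the ExecuTorch test harness, not part of this PR) that demonstrates both properties on a single 7-element slice:

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Numerically stable log_softmax over one 1-D slice: subtract the max before
// exponentiating, then subtract log-sum-exp from every element.
std::vector<double> log_softmax_1d(const std::vector<double>& x) {
  const double max_val = *std::max_element(x.begin(), x.end());
  double sum_exp = 0.0;
  for (double v : x) {
    sum_exp += std::exp(v - max_val);
  }
  const double log_sum_exp = max_val + std::log(sum_exp);
  std::vector<double> out(x.size());
  for (std::size_t i = 0; i < x.size(); ++i) {
    out[i] = x[i] - log_sum_exp;
  }
  return out;
}

int main() {
  // Same scaling as the test data above: sequential values times 0.01.
  const std::vector<double> x = {0.00, 0.01, 0.02, 0.03, 0.04, 0.05, 0.06};
  const std::vector<double> y = log_softmax_1d(x);
  double sum_exp = 0.0;
  for (double v : y) {
    assert(v <= 0.0);       // log of a probability is never positive
    sum_exp += std::exp(v); // exp'd outputs form a probability distribution
  }
  assert(std::abs(sum_exp - 1.0) < 1e-12);
  return 0;
}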