Skip to content

Commit d520798

Browse files
[WebGPU EP] Fix NaN bug in softmax operator (microsoft#24855)
Handle NaN in the softmax operator for the WebGPU EP and JSEP. Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent a612118 commit d520798

File tree

3 files changed

+22
-2
lines changed

3 files changed

+22
-2
lines changed

js/web/lib/wasm/jsep/webgpu/ops/softmax.ts

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,9 @@ const createSoftmaxProgramInfo = (context: ComputeContext, attributes: SoftmaxAt
152152
153153
// calculate final value for each element in the row
154154
for (var col = lindex; col < cols; col += wg) {
155-
let value = exp(getValue(row, col, row_stride) - rowMaxShared) / rowSumShared;
155+
var value = exp(getValue(row, col, row_stride) - rowMaxShared) / rowSumShared;
156+
// max operation protects against NaN since all values should be >=0
157+
value = max(value, ${valueType}(0.0));
156158
setValue(row, col, row_stride, value);
157159
}
158160
}`;

onnxruntime/core/providers/webgpu/math/softmax.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,9 @@ Status SoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const {
141141

142142
// Calculate the final value for each element in the row
143143
<< " for (var col = lindex; col < cols; col += wg) {\n"
144-
<< " let value = exp(getValue(row, col, row_stride) - row_max_shared) / row_sum_shared;\n"
144+
<< " var value = exp(getValue(row, col, row_stride) - row_max_shared) / row_sum_shared;\n"
145+
<< " // max operation protects against NaN since all values should be >=0\n"
146+
<< " value = max(value, x_value_t(0.0));\n"
145147
<< " setValue(row, col, row_stride, value);\n"
146148
<< " }\n";
147149

onnxruntime/test/providers/cpu/math/softmax_test.cc

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,22 @@ TEST(SoftmaxOperator, Simple) {
4949
RunTest(x_vals, expected_vals, dimensions);
5050
}
5151

52+
#ifdef USE_WEBGPU
53+
TEST(SoftmaxOperator, webgpu_nan) {
54+
OpTester test("Softmax", 13); // axis default is -1
55+
56+
std::vector<float> x_vals = {-INFINITY, -INFINITY, -INFINITY};
57+
std::vector<float> expected_result = {0.0f, 0.0f, 0.0f};
58+
std::vector<int64_t> dimensions = {1, 3};
59+
60+
test.AddInput<float>("X", dimensions, x_vals);
61+
test.AddOutput<float>("Y", dimensions, expected_result);
62+
63+
// explicitly disable CPU EP for this test since CPU implementation does not handle NaN
64+
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCpuExecutionProvider});
65+
}
66+
#endif
67+
5268
#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_XNNPACK)
5369
TEST(SoftmaxOperator, Simple_fp16) {
5470
#ifdef USE_CUDA

0 commit comments

Comments (0)