Skip to content

Commit 71db148

Browse files
committed
[Relax] Fix HardSigmoid returns 1.0 for NaN input
1 parent fa905d2 commit 71db148

File tree

2 files changed

+37
-1
lines changed

2 files changed

+37
-1
lines changed

python/tvm/relax/frontend/onnx/onnx_frontend.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3255,7 +3255,18 @@ def _impl_v1(cls, bb, inputs, attr, params):
32553255
alpha = relax.const(alpha, dtype=dtype)
32563256
beta = float(attr.get("beta", 0.5))
32573257
beta = relax.const(beta, dtype=dtype)
3258-
return relax.op.clip(relax.op.add(relax.op.multiply(alpha, x), beta), 0, 1)
3258+
clipped = relax.op.clip(relax.op.add(relax.op.multiply(alpha, x), beta), 0, 1)
3259+
3260+
# Preserve NaN values: where x is NaN, return NaN instead of clipped value
3261+
if isinstance(x, relax.Constant):
3262+
x_data = x.data.numpy()
3263+
is_nan_data = _np.isnan(x_data)
3264+
is_nan = relax.const(is_nan_data, dtype="bool")
3265+
else:
3266+
is_nan = relax.op.not_equal(x, x)
3267+
3268+
nan_val = relax.const(_np.nan, dtype=dtype)
3269+
return relax.op.where(is_nan, nan_val, clipped)
32593270

32603271

32613272
class HardSwish(OnnxOpConverter):

tests/python/relax/test_frontend_onnx.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1086,6 +1086,31 @@ def test_hardsigmoid():
10861086
verify_unary("HardSigmoid", [1, 3, 20, 20], attrs={"alpha": 0.5, "beta": 0.6})
10871087

10881088

1089+
def test_hardsigmoid_nan():
1090+
"""Test that HardSigmoid preserves NaN values in output."""
1091+
test_node = helper.make_node("HardSigmoid", ["x"], ["y"])
1092+
graph = helper.make_graph(
1093+
[test_node],
1094+
"hardsigmoid_nan_test",
1095+
inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, [3, 4])],
1096+
outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [3, 4])],
1097+
)
1098+
1099+
model = helper.make_model(graph, producer_name="hardsigmoid_nan_test")
1100+
1101+
# Create input with NaN values
1102+
input_data = np.array(
1103+
[
1104+
[np.nan, 0.5, -0.5, 1.0],
1105+
[0.0, np.nan, 2.0, -2.0],
1106+
[0.3, 0.7, np.nan, np.nan],
1107+
],
1108+
dtype=np.float32,
1109+
)
1110+
1111+
check_correctness(model, inputs={"x": input_data})
1112+
1113+
10891114
def test_shrink():
10901115
verify_unary("Shrink", [32, 32])
10911116
verify_unary("Shrink", [32, 32], attrs={"lambd": 0.2, "bias": 0.1})

0 commit comments

Comments
 (0)