Skip to content

Commit 2413c4a

Browse files
committed
[Relax] Fix HardSigmoid returns 1.0 for NaN input
1 parent fed71ef commit 2413c4a

File tree

2 files changed

+32
-1
lines changed

2 files changed

+32
-1
lines changed

python/tvm/relax/frontend/onnx/onnx_frontend.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3350,7 +3350,13 @@ def _impl_v1(cls, bb, inputs, attr, params):
33503350
alpha = relax.const(alpha, dtype=dtype)
33513351
beta = float(attr.get("beta", 0.5))
33523352
beta = relax.const(beta, dtype=dtype)
3353-
return relax.op.clip(relax.op.add(relax.op.multiply(alpha, x), beta), 0, 1)
3353+
3354+
is_nan = bb.emit_te(topi.isnan, x)
3355+
transformed = bb.emit(relax.op.add(relax.op.multiply(alpha, x), beta))
3356+
clamped = bb.emit_te(topi.maximum, transformed, 0.0)
3357+
clamped = bb.emit_te(topi.minimum, clamped, 1.0)
3358+
3359+
return bb.emit_te(topi.where, is_nan, x, clamped)
33543360

33553361

33563362
class HardSwish(OnnxOpConverter):

tests/python/relax/test_frontend_onnx.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1089,6 +1089,31 @@ def test_hardsigmoid():
10891089
verify_unary("HardSigmoid", [1, 3, 20, 20], attrs={"alpha": 0.5, "beta": 0.6})
10901090

10911091

1092+
def test_hardsigmoid_nan():
1093+
"""Test that HardSigmoid preserves NaN values in output."""
1094+
test_node = helper.make_node("HardSigmoid", ["x"], ["y"])
1095+
graph = helper.make_graph(
1096+
[test_node],
1097+
"hardsigmoid_nan_test",
1098+
inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, [3, 4])],
1099+
outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [3, 4])],
1100+
)
1101+
1102+
model = helper.make_model(graph, producer_name="hardsigmoid_nan_test")
1103+
1104+
# Create input with NaN values
1105+
input_data = np.array(
1106+
[
1107+
[np.nan, 0.5, -0.5, 1.0],
1108+
[0.0, np.nan, 2.0, -2.0],
1109+
[0.3, 0.7, np.nan, np.nan],
1110+
],
1111+
dtype=np.float32,
1112+
)
1113+
1114+
check_correctness(model, inputs={"x": input_data})
1115+
1116+
10921117
def test_shrink():
10931118
verify_unary("Shrink", [32, 32])
10941119
verify_unary("Shrink", [32, 32], attrs={"lambd": 0.2, "bias": 0.1})

0 commit comments

Comments
 (0)