Skip to content

Commit 351b170

Browse files
committed
[Relax] Fix HardSigmoid returns 1.0 for NaN input
1 parent fed71ef commit 351b170

File tree

2 files changed

+34
-5
lines changed

2 files changed

+34
-5
lines changed

python/tvm/relax/frontend/onnx/onnx_frontend.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3346,11 +3346,15 @@ class HardSigmoid(OnnxOpConverter):
33463346
def _impl_v1(cls, bb, inputs, attr, params):
33473347
x = inputs[0]
33483348
dtype = x.struct_info.dtype
3349-
alpha = float(attr.get("alpha", 0.2))
3350-
alpha = relax.const(alpha, dtype=dtype)
3351-
beta = float(attr.get("beta", 0.5))
3352-
beta = relax.const(beta, dtype=dtype)
3353-
return relax.op.clip(relax.op.add(relax.op.multiply(alpha, x), beta), 0, 1)
3349+
alpha = float(attr.get("alpha", 0.2), dtype)
3350+
beta = float(attr.get("beta", 0.5), dtype)
3351+
3352+
is_nan = bb.emit_te(topi.isnan, x)
3353+
transformed = relax.op.add(relax.op.multiply(alpha, x), beta)
3354+
clamped = bb.emit_te(topi.maximum, transformed, 0.0)
3355+
clamped = bb.emit_te(topi.minimum, clamped, 1.0)
3356+
3357+
return bb.emit_te(topi.where, is_nan, x, clamped)
33543358

33553359

33563360
class HardSwish(OnnxOpConverter):

tests/python/relax/test_frontend_onnx.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1089,6 +1089,31 @@ def test_hardsigmoid():
10891089
verify_unary("HardSigmoid", [1, 3, 20, 20], attrs={"alpha": 0.5, "beta": 0.6})
10901090

10911091

1092+
def test_hardsigmoid_nan():
1093+
"""Test that HardSigmoid preserves NaN values in output."""
1094+
test_node = helper.make_node("HardSigmoid", ["x"], ["y"])
1095+
graph = helper.make_graph(
1096+
[test_node],
1097+
"hardsigmoid_nan_test",
1098+
inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, [3, 4])],
1099+
outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [3, 4])],
1100+
)
1101+
1102+
model = helper.make_model(graph, producer_name="hardsigmoid_nan_test")
1103+
1104+
# Create input with NaN values
1105+
input_data = np.array(
1106+
[
1107+
[np.nan, 0.5, -0.5, 1.0],
1108+
[0.0, np.nan, 2.0, -2.0],
1109+
[0.3, 0.7, np.nan, np.nan],
1110+
],
1111+
dtype=np.float32,
1112+
)
1113+
1114+
check_correctness(model, inputs={"x": input_data})
1115+
1116+
10921117
def test_shrink():
10931118
verify_unary("Shrink", [32, 32])
10941119
verify_unary("Shrink", [32, 32], attrs={"lambd": 0.2, "bias": 0.1})

0 commit comments

Comments
 (0)