Skip to content

Commit 0ec46e2

Browse files
justinchubyCopilot
andauthored
[torchlib] Implement signbit (#2754)
Signed-off-by: Justin Chu <[email protected]> Co-authored-by: Copilot <[email protected]>
1 parent a571309 commit 0ec46e2

File tree

2 files changed

+13
-4
lines changed

2 files changed

+13
-4
lines changed

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8755,17 +8755,25 @@ def aten_sigmoid(self: TFloat) -> TFloat:
87558755
return op.Sigmoid(self)
87568756

87578757

8758-
@torch_op("aten::sign")
8759-
def aten_sign(self: TReal) -> TReal:
8758+
@torch_op("aten::sign", trace_only=True)
8759+
def aten_sign(self: TensorType) -> TensorType:
87608760
"""sign(Tensor self) -> Tensor"""
87618761

8762+
if self.dtype == ir.DataType.BOOL:
8763+
return op.Identity(self)
8764+
87628765
return op.Sign(self)
87638766

87648767

8765-
def aten_signbit(self: TensorType) -> TensorType:
8768+
@torch_op("aten::signbit", trace_only=True)
8769+
def aten_signbit(self: TensorType) -> BOOL:
87668770
"""signbit(Tensor self) -> Tensor"""
87678771

8768-
raise NotImplementedError()
8772+
if self.dtype == ir.DataType.BOOL:
8773+
return op.ConstantOfShape(op.Shape(self), value=ir.tensor([False]))
8774+
8775+
# -0.0 should return True, but ONNX does not have an appropriate operator to handle it.
8776+
return op.Less(self, op.Constant(value=ir.tensor([0], dtype=self.dtype)))
87698777

87708778

87718779
@torch_op("aten::sin", trace_only=True)

tests/function_libs/torch_lib/ops_test_data.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1169,6 +1169,7 @@ def _where_input_wrangler(
11691169
TorchLibOpInfo("select_scatter", core_ops.aten_select_scatter),
11701170
TorchLibOpInfo("sigmoid", core_ops.aten_sigmoid),
11711171
TorchLibOpInfo("sign", core_ops.aten_sign),
1172+
TorchLibOpInfo("signbit", core_ops.aten_signbit),
11721173
TorchLibOpInfo("nn.functional.silu", nn_ops.aten_silu),
11731174
TorchLibOpInfo("sin", core_ops.aten_sin),
11741175
TorchLibOpInfo(

0 commit comments

Comments
 (0)