File tree Expand file tree Collapse file tree 2 files changed +13
-4
lines changed
onnxscript/function_libs/torch_lib/ops
tests/function_libs/torch_lib Expand file tree Collapse file tree 2 files changed +13
-4
lines changed Original file line number Diff line number Diff line change @@ -8755,17 +8755,25 @@ def aten_sigmoid(self: TFloat) -> TFloat:
87558755 return op .Sigmoid (self )
87568756
87578757
8758- @torch_op ("aten::sign" )
8759- def aten_sign (self : TReal ) -> TReal :
8758+ @torch_op ("aten::sign" , trace_only = True )
8759+ def aten_sign (self : TensorType ) -> TensorType :
87608760 """sign(Tensor self) -> Tensor"""
87618761
8762+ if self .dtype == ir .DataType .BOOL :
8763+ return op .Identity (self )
8764+
87628765 return op .Sign (self )
87638766
87648767
8765- def aten_signbit (self : TensorType ) -> TensorType :
8768+ @torch_op ("aten::signbit" , trace_only = True )
8769+ def aten_signbit (self : TensorType ) -> BOOL :
87668770 """signbit(Tensor self) -> Tensor"""
87678771
8768- raise NotImplementedError ()
8772+ if self .dtype == ir .DataType .BOOL :
8773+ return op .ConstantOfShape (op .Shape (self ), value = ir .tensor ([False ]))
8774+
8775+ # -0.0 should return True, but ONNX does not have an appropriate operator to handle it.
8776+ return op .Less (self , op .Constant (value = ir .tensor ([0 ], dtype = self .dtype )))
87698777
87708778
87718779@torch_op ("aten::sin" , trace_only = True )
Original file line number Diff line number Diff line change @@ -1169,6 +1169,7 @@ def _where_input_wrangler(
11691169 TorchLibOpInfo ("select_scatter" , core_ops .aten_select_scatter ),
11701170 TorchLibOpInfo ("sigmoid" , core_ops .aten_sigmoid ),
11711171 TorchLibOpInfo ("sign" , core_ops .aten_sign ),
1172+ TorchLibOpInfo ("signbit" , core_ops .aten_signbit ),
11721173 TorchLibOpInfo ("nn.functional.silu" , nn_ops .aten_silu ),
11731174 TorchLibOpInfo ("sin" , core_ops .aten_sin ),
11741175 TorchLibOpInfo (
You can’t perform that action at this time.
0 commit comments