Skip to content

Commit 4ce07c9

Browse files
committed
Refactor
Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
1 parent 98e81a2 commit 4ce07c9

File tree

1 file changed

+11
-21
lines changed
  • onnxscript/function_libs/torch_lib/ops

1 file changed

+11
-21
lines changed

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 11 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -7657,11 +7657,8 @@ def aten_refine_names(self: TensorType, names: Sequence[str]) -> TensorType:
76577657
raise NotImplementedError()
76587658

76597659

7660-
@torch_op("aten::remainder.Tensor", trace_only=True)
7661-
def aten_remainder(self: TTensor, other: TTensor) -> TTensor:
7662-
"""remainder.Tensor(Tensor self, Tensor other) -> Tensor"""
7663-
7664-
if self.dtype.is_integer():
7660+
def _aten_remainder(self: TTensor, other: TTensor, integer: bool) -> TTensor:
7661+
if integer:
76657662
return op.Mod(self, other)
76667663

76677664
# TODO(justinchuby): Improve fp16 precision by following the logic in
@@ -7673,34 +7670,27 @@ def aten_remainder(self: TTensor, other: TTensor) -> TTensor:
76737670
return op.Sub(self, op.Mul(rounded_quotient, other))
76747671

76757672

7673+
@torch_op("aten::remainder.Tensor", trace_only=True)
7674+
def aten_remainder(self: TTensor, other: TTensor) -> TTensor:
7675+
"""remainder.Tensor(Tensor self, Tensor other) -> Tensor"""
7676+
7677+
return _aten_remainder(self, other, integer=self.dtype.is_integer())
7678+
7679+
76767680
@torch_op("aten::remainder.Scalar", trace_only=True)
76777681
def aten_remainder_scalar(self: TTensor, other: float) -> TTensor:
76787682
"""remainder.Scalar(Tensor self, Scalar other) -> Tensor"""
76797683

76807684
other_tensor = ir.tensor(other, dtype=self.dtype)
7681-
7682-
if self.dtype.is_integer():
7683-
return op.Mod(self, other_tensor)
7684-
7685-
# a - a.div(b, rounding_mode="floor") * b
7686-
rounded_quotient = op.Floor(op.Div(self, other_tensor))
7687-
7688-
return op.Sub(self, op.Mul(rounded_quotient, other_tensor))
7685+
return _aten_remainder(self, other_tensor, integer=self.dtype.is_integer())
76897686

76907687

76917688
@torch_op("aten::remainder.Scalar_Tensor", trace_only=True)
76927689
def aten_remainder_scalar_tensor(self: float, other: TTensor) -> TTensor:
76937690
"""remainder.Scalar_Tensor(Scalar self, Tensor other) -> Tensor"""
76947691

76957692
self_tensor = ir.tensor(self, dtype=other.dtype)
7696-
7697-
if other.dtype.is_integer():
7698-
return op.Mod(self_tensor, other)
7699-
7700-
# a - a.div(b, rounding_mode="floor") * b
7701-
rounded_quotient = op.Floor(op.Div(self_tensor, other))
7702-
7703-
return op.Sub(self_tensor, op.Mul(rounded_quotient, other))
7693+
return _aten_remainder(self_tensor, other, integer=other.dtype.is_integer())
77047694

77057695

77067696
@torch_op("_operator::mod", trace_only=True)

0 commit comments

Comments
 (0)