@@ -7657,11 +7657,8 @@ def aten_refine_names(self: TensorType, names: Sequence[str]) -> TensorType:
76577657 raise NotImplementedError ()
76587658
76597659
7660- @torch_op ("aten::remainder.Tensor" , trace_only = True )
7661- def aten_remainder (self : TTensor , other : TTensor ) -> TTensor :
7662- """remainder.Tensor(Tensor self, Tensor other) -> Tensor"""
7663-
7664- if self .dtype .is_integer ():
7660+ def _aten_remainder (self : TTensor , other : TTensor , integer : bool ) -> TTensor :
7661+ if integer :
76657662 return op .Mod (self , other )
76667663
76677664 # TODO(justinchuby): Improve fp16 precision by following the logic in
@@ -7673,34 +7670,27 @@ def aten_remainder(self: TTensor, other: TTensor) -> TTensor:
76737670 return op .Sub (self , op .Mul (rounded_quotient , other ))
76747671
76757672
7673+ @torch_op ("aten::remainder.Tensor" , trace_only = True )
7674+ def aten_remainder (self : TTensor , other : TTensor ) -> TTensor :
7675+ """remainder.Tensor(Tensor self, Tensor other) -> Tensor"""
7676+
7677+ return _aten_remainder (self , other , integer = self .dtype .is_integer ())
7678+
7679+
76767680@torch_op ("aten::remainder.Scalar" , trace_only = True )
76777681def aten_remainder_scalar (self : TTensor , other : float ) -> TTensor :
76787682 """remainder.Scalar(Tensor self, Scalar other) -> Tensor"""
76797683
76807684 other_tensor = ir .tensor (other , dtype = self .dtype )
7681-
7682- if self .dtype .is_integer ():
7683- return op .Mod (self , other_tensor )
7684-
7685- # a - a.div(b, rounding_mode="floor") * b
7686- rounded_quotient = op .Floor (op .Div (self , other_tensor ))
7687-
7688- return op .Sub (self , op .Mul (rounded_quotient , other_tensor ))
7685+ return _aten_remainder (self , other_tensor , integer = self .dtype .is_integer ())
76897686
76907687
76917688@torch_op ("aten::remainder.Scalar_Tensor" , trace_only = True )
76927689def aten_remainder_scalar_tensor (self : float , other : TTensor ) -> TTensor :
76937690 """remainder.Scalar_Tensor(Scalar self, Tensor other) -> Tensor"""
76947691
76957692 self_tensor = ir .tensor (self , dtype = other .dtype )
7696-
7697- if other .dtype .is_integer ():
7698- return op .Mod (self_tensor , other )
7699-
7700- # a - a.div(b, rounding_mode="floor") * b
7701- rounded_quotient = op .Floor (op .Div (self_tensor , other ))
7702-
7703- return op .Sub (self_tensor , op .Mul (rounded_quotient , other ))
7693+ return _aten_remainder (self_tensor , other , integer = other .dtype .is_integer ())
77047694
77057695
77067696@torch_op ("_operator::mod" , trace_only = True )
0 commit comments