Skip to content

Commit 7f9c064

Browse files
bobrenjc93pytorchmergebot
authored andcommitted
fix mypi in utils/_sympy/functions.py (pytorch#136339)
Signed-off-by: Bob Ren <[email protected]> Turns out older versions of python, in particular 3.8 shows errors that 3.12 doesn't. For posterity these are the steps I took to reproduce: ``` conda create -n py38 python=3.8 conda activate py38 pip install -r requirements.txt lintrunner init dmypy restart && lintrunner --all-files --take MYPY ``` Pull Request resolved: pytorch#136339 Approved by: https://github.com/Skylion007 ghstack dependencies: pytorch#136205
1 parent f53a0f9 commit 7f9c064

File tree

1 file changed

+12
-10
lines changed

1 file changed

+12
-10
lines changed

torch/utils/_sympy/functions.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -89,10 +89,10 @@
8989
]
9090

9191

92-
def _keep_float(f: Callable[..., _T]) -> Callable[..., sympy.Float]:
92+
def _keep_float(f: Callable[..., _T]) -> Callable[..., Union[_T, sympy.Float]]:
9393
@functools.wraps(f)
9494
def inner(*args: Any) -> Union[_T, sympy.Float]:
95-
r = f(*args)
95+
r: Union[_T, sympy.Float] = f(*args)
9696
if any(isinstance(a, sympy.Float) for a in args) and not isinstance(
9797
r, sympy.Float
9898
):
@@ -140,16 +140,16 @@ def integer_factor(expr: sympy.Basic) -> int:
140140
return functools.reduce(math.gcd, integer_factors)
141141

142142
gcd: int = math.gcd(integer_factor(p), integer_factor(q))
143-
p, q = p / gcd, q / gcd
143+
p, q = p / gcd, q / gcd # type: ignore[operator, assignment] # remove in py3.12
144144

145145
base_splits: List[Tuple[sympy.Basic, ...]] = list(
146146
map(sympy.Mul.make_args, sympy.Add.make_args(p))
147147
)
148148
divisor_split: Tuple[sympy.Basic, ...] = sympy.Mul.make_args(q)
149149
for x in divisor_split:
150150
if all(x in base_split for base_split in base_splits):
151-
gcd = gcd * x
152-
return gcd
151+
gcd = gcd * x # type: ignore[operator] # remove in py3.12
152+
return gcd # type: ignore[return-value] # remove in py3.12
153153

154154

155155
# It would be nice to have assertions on whether or not inputs is_integer
@@ -191,15 +191,17 @@ def base(self) -> sympy.Basic:
191191
def divisor(self) -> sympy.Basic:
192192
return self.args[1]
193193

194-
def _sympystr(self, printer: sympy.printing.printer.Printer) -> str:
194+
def _sympystr(self, printer: sympy.printing.StrPrinter) -> str:
195195
base = printer.parenthesize(self.base, self.precedence)
196196
divisor = printer.parenthesize(self.divisor, self.precedence)
197197
return f"({base}//{divisor})"
198198

199199
# Automatic evaluation.
200200
# https://docs.sympy.org/latest/guides/custom-functions.html#best-practices-for-eval
201201
@classmethod
202-
def eval(cls, base: sympy.Basic, divisor: sympy.Basic) -> Union[sympy.Basic, None]:
202+
def eval(
203+
cls, base: sympy.Integer, divisor: sympy.Integer
204+
) -> Union[sympy.Basic, None]:
203205
# python test/test_dynamic_shapes.py -k TestDimConstraints.test_dim_constraints_solve_full
204206
# Assert triggered by inequality solver
205207
# assert base.is_integer, base
@@ -281,7 +283,7 @@ class ModularIndexing(sympy.Function):
281283

282284
@classmethod
283285
def eval(
284-
cls, base: sympy.Basic, divisor: sympy.Basic, modulus: sympy.Basic
286+
cls, base: sympy.Integer, divisor: sympy.Integer, modulus: sympy.Integer
285287
) -> Optional[sympy.Basic]:
286288
if base == 0 or modulus == 1:
287289
return sympy.Integer(0)
@@ -306,7 +308,7 @@ def eval(
306308
pass # https://github.com/pytorch/pytorch/issues/108276
307309

308310
if isinstance(base, sympy.Add):
309-
new_terms: List[sympy.Basic] = []
311+
new_terms: List[sympy.Integer] = []
310312
all_positive: bool = True
311313
for term in base.args:
312314
if sympy.gcd(term, modulus * divisor) != modulus * divisor:
@@ -1156,7 +1158,7 @@ class Identity(sympy.Function):
11561158
Prevents expansion and other optimizations
11571159
"""
11581160

1159-
def __repr__(self):
1161+
def __repr__(self): # type: ignore[override]
11601162
return f"Identity({self.args[0]})"
11611163

11621164
def _eval_is_real(self):

0 commit comments

Comments
 (0)