8989]
9090
9191
92- def _keep_float (f : Callable [..., _T ]) -> Callable [..., sympy .Float ]:
92+ def _keep_float (f : Callable [..., _T ]) -> Callable [..., Union [ _T , sympy .Float ] ]:
9393 @functools .wraps (f )
9494 def inner (* args : Any ) -> Union [_T , sympy .Float ]:
95- r = f (* args )
95+ r : Union [ _T , sympy . Float ] = f (* args )
9696 if any (isinstance (a , sympy .Float ) for a in args ) and not isinstance (
9797 r , sympy .Float
9898 ):
@@ -140,16 +140,16 @@ def integer_factor(expr: sympy.Basic) -> int:
140140 return functools .reduce (math .gcd , integer_factors )
141141
142142 gcd : int = math .gcd (integer_factor (p ), integer_factor (q ))
143- p , q = p / gcd , q / gcd
143+ p , q = p / gcd , q / gcd # type: ignore[operator, assignment] # remove in py3.12
144144
145145 base_splits : List [Tuple [sympy .Basic , ...]] = list (
146146 map (sympy .Mul .make_args , sympy .Add .make_args (p ))
147147 )
148148 divisor_split : Tuple [sympy .Basic , ...] = sympy .Mul .make_args (q )
149149 for x in divisor_split :
150150 if all (x in base_split for base_split in base_splits ):
151- gcd = gcd * x
152- return gcd
151+ gcd = gcd * x # type: ignore[operator] # remove in py3.12
152+ return gcd # type: ignore[return-value] # remove in py3.12
153153
154154
155155# It would be nice to have assertions on whether or not inputs is_integer
@@ -191,15 +191,17 @@ def base(self) -> sympy.Basic:
191191 def divisor (self ) -> sympy .Basic :
192192 return self .args [1 ]
193193
194- def _sympystr (self , printer : sympy .printing .printer . Printer ) -> str :
194+ def _sympystr (self , printer : sympy .printing .StrPrinter ) -> str :
195195 base = printer .parenthesize (self .base , self .precedence )
196196 divisor = printer .parenthesize (self .divisor , self .precedence )
197197 return f"({ base } //{ divisor } )"
198198
199199 # Automatic evaluation.
200200 # https://docs.sympy.org/latest/guides/custom-functions.html#best-practices-for-eval
201201 @classmethod
202- def eval (cls , base : sympy .Basic , divisor : sympy .Basic ) -> Union [sympy .Basic , None ]:
202+ def eval (
203+ cls , base : sympy .Integer , divisor : sympy .Integer
204+ ) -> Union [sympy .Basic , None ]:
203205 # python test/test_dynamic_shapes.py -k TestDimConstraints.test_dim_constraints_solve_full
204206 # Assert triggered by inequality solver
205207 # assert base.is_integer, base
@@ -281,7 +283,7 @@ class ModularIndexing(sympy.Function):
281283
282284 @classmethod
283285 def eval (
284- cls , base : sympy .Basic , divisor : sympy .Basic , modulus : sympy .Basic
286+ cls , base : sympy .Integer , divisor : sympy .Integer , modulus : sympy .Integer
285287 ) -> Optional [sympy .Basic ]:
286288 if base == 0 or modulus == 1 :
287289 return sympy .Integer (0 )
@@ -306,7 +308,7 @@ def eval(
306308 pass # https://github.com/pytorch/pytorch/issues/108276
307309
308310 if isinstance (base , sympy .Add ):
309- new_terms : List [sympy .Basic ] = []
311+ new_terms : List [sympy .Integer ] = []
310312 all_positive : bool = True
311313 for term in base .args :
312314 if sympy .gcd (term , modulus * divisor ) != modulus * divisor :
@@ -1156,7 +1158,7 @@ class Identity(sympy.Function):
11561158 Prevents expansion and other optimizations
11571159 """
11581160
1159- def __repr__ (self ):
1161+ def __repr__ (self ): # type: ignore[override]
11601162 return f"Identity({ self .args [0 ]} )"
11611163
11621164 def _eval_is_real (self ):
0 commit comments