88import random
99
1010import flint
11+ import flint .typing as typ
1112import flint .flint_base .flint_base as flint_base
1213from flint .utils .flint_exceptions import DomainError , IncompatibleContextError
1314
@@ -26,13 +27,10 @@ def raises(f, exception) -> bool:
2627
2728if TYPE_CHECKING :
2829 from typing import TypeIs
29- from flint .flint_base .flint_base import _flint_poly_exact
3030
3131
3232Tscalar = TypeVar ('Tscalar' , bound = flint_base .flint_scalar )
3333Tscalar_co = TypeVar ('Tscalar_co' , bound = flint_base .flint_scalar , covariant = True )
34- Tscalar_contra = TypeVar ('Tscalar_contra' , bound = flint_base .flint_scalar , contravariant = True )
35- Tpoly = TypeVar ("Tpoly" , bound = '_flint_poly_exact' )
3634Tmpoly = TypeVar ('Tmpoly' , bound = flint_base .flint_mpoly )
3735Tmpolyctx_co = TypeVar ('Tmpolyctx_co' , bound = flint_base .flint_mpoly_context , covariant = True )
3836
@@ -2621,32 +2619,26 @@ def _all_polys() -> list[tuple[Any, Any, bool, flint.fmpz]]:
26212619 ]
26222620
26232621
2624- class _TPoly (Protocol [Tpoly , Tscalar_contra ]):
2625- def __call__ (
2626- self , x : Sequence [Tscalar_contra | int ] | Tpoly | Tscalar_contra | int , /
2627- ) -> Tpoly : ...
2628-
2629-
2630- class _Telem (Protocol [Tscalar ]):
2631- def __call__ (self , x : int | Tscalar , / ) -> Tscalar : ...
2632-
2633-
2634- _PolyTestCase = tuple [_TPoly [Tpoly , Tscalar ], _Telem [Tscalar ], bool , flint .fmpz ]
2622+ Tpoly = TypeVar ("Tpoly" , bound = typ .epoly_p )
2623+ Tc = TypeVar ("Tc" , bound = flint_base .flint_scalar )
2624+ TS = Callable [[Tc | int ], Tc ]
2625+ TP = Callable [[Tpoly | Sequence [Tc | int ] | Tc | int ], Tpoly ]
2626+ _PolyTestCase = tuple [TP [Tpoly ,Tc ], TS [Tc ], bool , flint .fmpz ]
26352627
26362628
26372629def _for_all_polys (test : Callable [[_PolyTestCase ], None ]) -> None :
26382630 """Test all mpoly types with the given test function."""
26392631 # Spell it out like this so that a type checker can understand the types
26402632 # in the generics for each call of test().
26412633
2642- fmpz : _Telem [flint .fmpz ] = flint .fmpz
2643- fmpq : _Telem [flint .fmpq ] = flint .fmpq
2644- fmpz_poly : _TPoly [flint .fmpz_poly , flint .fmpz ] = flint .fmpz_poly
2645- fmpq_poly : _TPoly [flint .fmpq_poly , flint .fmpq ] = flint .fmpq_poly
2634+ fmpz : TS [flint .fmpz ] = flint .fmpz
2635+ fmpq : TS [flint .fmpq ] = flint .fmpq
2636+ fmpz_poly : TP [flint .fmpz_poly , flint .fmpz ] = flint .fmpz_poly
2637+ fmpq_poly : TP [flint .fmpq_poly , flint .fmpq ] = flint .fmpq_poly
26462638
26472639 def nmod_poly (
26482640 p : int ,
2649- ) -> tuple [_TPoly [flint .nmod_poly , flint .nmod ], _Telem [flint .nmod ]]:
2641+ ) -> tuple [TP [flint .nmod_poly , flint .nmod ], TS [flint .nmod ]]:
26502642 """Make nmod poly and scalar constructors for modulus p."""
26512643
26522644 def poly (
@@ -2661,7 +2653,7 @@ def elem(x: int | flint.nmod = 0, /) -> flint.nmod:
26612653
26622654 def fmpz_mod_poly (
26632655 p : int ,
2664- ) -> tuple [_TPoly [flint .fmpz_mod_poly , flint .fmpz_mod ], _Telem [flint .fmpz_mod ]]:
2656+ ) -> tuple [TP [flint .fmpz_mod_poly , flint .fmpz_mod ], TS [flint .fmpz_mod ]]:
26652657 """Make fmpz_mod poly and scalar constructors for modulus p."""
26662658 ectx = flint .fmpz_mod_ctx (p )
26672659 pctx = flint .fmpz_mod_poly_ctx (ectx )
@@ -2683,7 +2675,7 @@ def elem(x: int | flint.fmpz_mod = 0, /) -> flint.fmpz_mod:
26832675 def fq_default_poly (
26842676 p : int , k : int | None = None
26852677 ) -> tuple [
2686- _TPoly [flint .fq_default_poly , flint .fq_default ], _Telem [flint .fq_default ]
2678+ TP [flint .fq_default_poly , flint .fq_default ], TS [flint .fq_default ]
26872679 ]:
26882680 """Make fq_default poly and scalar constructors for field p^k."""
26892681 if k is None :
@@ -2740,7 +2732,7 @@ def wrapper():
27402732
27412733
27422734@all_polys
2743- def test_polys (args : _PolyTestCase [Tpoly , Tscalar ]) -> None :
2735+ def test_polys (args : _PolyTestCase [typ . epoly_p [ Tc ], Tc ]) -> None :
27442736 # To test type annotations, uncomment:
27452737 # P: type[flint.fmpq_poly]
27462738 # S: type[flint.fmpq]
@@ -2872,7 +2864,7 @@ def setbad(obj, i, val):
28722864
28732865 assert P ([1 , 2 , 3 ]) + P ([4 , 5 , 6 ]) == P ([5 , 7 , 9 ])
28742866
2875- for T in [ int , S , flint .fmpz ] :
2867+ for T in ( int , S , flint .fmpz ) :
28762868 assert P ([1 , 2 , 3 ]) + T (1 ) == P ([2 , 2 , 3 ])
28772869 assert T (1 ) + P ([1 , 2 , 3 ]) == P ([2 , 2 , 3 ])
28782870
@@ -2881,7 +2873,7 @@ def setbad(obj, i, val):
28812873
28822874 assert P ([1 , 2 , 3 ]) - P ([4 , 5 , 6 ]) == P ([- 3 , - 3 , - 3 ])
28832875
2884- for T in [ int , S , flint .fmpz ] :
2876+ for T in ( int , S , flint .fmpz ) :
28852877 assert P ([1 , 2 , 3 ]) - T (1 ) == P ([0 , 2 , 3 ])
28862878 assert T (1 ) - P ([1 , 2 , 3 ]) == P ([0 , - 2 , - 3 ])
28872879
@@ -2890,7 +2882,7 @@ def setbad(obj, i, val):
28902882
28912883 assert P ([1 , 2 , 3 ]) * P ([4 , 5 , 6 ]) == P ([4 , 13 , 28 , 27 , 18 ])
28922884
2893- for T in [ int , S , flint .fmpz ] :
2885+ for T in ( int , S , flint .fmpz ) :
28942886 assert P ([1 , 2 , 3 ]) * T (2 ) == P ([2 , 4 , 6 ])
28952887 assert T (2 ) * P ([1 , 2 , 3 ]) == P ([2 , 4 , 6 ])
28962888
0 commit comments