11from __future__ import annotations
2- from typing import Any , Callable , TypeVar , Iterable , Protocol
2+ from typing import Any , Callable , TypeVar , Iterable , Protocol , TYPE_CHECKING
33
44import math
55import operator
66import pickle
77import platform
88import random
9- from functools import wraps
109
1110import flint
1211import flint .flint_base .flint_base as flint_base
@@ -25,9 +24,12 @@ def raises(f, exception):
2524 return False
2625
2726
28- Tscalar = TypeVar ('Tscalar' , bound = flint_base .flint_scalar )
29- Tmpoly = TypeVar ('Tmpoly' , bound = flint_base .flint_mpoly )
30- Tmpolyctx_co = TypeVar ('Tmpolyctx_co' , bound = flint_base .flint_mpoly_context , covariant = True )
27+ if TYPE_CHECKING :
28+ from typing import TypeIs
29+ Tscalar = TypeVar ('Tscalar' , bound = flint_base .flint_scalar )
30+ Tscalar_co = TypeVar ('Tscalar_co' , bound = flint_base .flint_scalar , covariant = True )
31+ Tmpoly = TypeVar ('Tmpoly' , bound = flint_base .flint_mpoly )
32+ Tmpolyctx_co = TypeVar ('Tmpolyctx_co' , bound = flint_base .flint_mpoly_context , covariant = True )
3133
3234
3335_default_ctx_string = """\
@@ -2943,6 +2945,15 @@ def __call__(self,
29432945]
29442946
29452947
2948+ class _Q (Protocol [Tscalar_co ]):
2949+ def __call__ (self , a : int , b : int | None = None , / ) -> Tscalar_co :
2950+ ...
2951+
2952+
2953+ def _is_Q (typ : object ) -> TypeIs [_Q ]:
2954+ return typ is flint .fmpq
2955+
2956+
29462957def _for_all_mpolys (test : Callable [[_MPolyTestCase ], None ]) -> None :
29472958 """Test all mpoly types with the given test function."""
29482959 # Spell it out like this so that a type checker can understand the types
@@ -3001,7 +3012,7 @@ def wrapper():
30013012
30023013@all_mpolys
30033014def test_mpolys_constructor (args : _MPolyTestCase [Tmpoly , Tscalar ]) -> None :
3004- P , get_context , S , is_field , characteristic = args
3015+ P , get_context , S , _ , _ = args
30053016
30063017 ctx = get_context (("x" , 2 ))
30073018
@@ -3487,7 +3498,7 @@ def quick_poly():
34873498 assert raises (lambda : p .derivative (None ), TypeError ) # type: ignore
34883499
34893500 if isinstance (p , (flint .fmpz_mpoly , flint .fmpq_mpoly )):
3490- if isinstance (p , flint .fmpq_mpoly ):
3501+ if isinstance (p , flint .fmpq_mpoly ) and _is_Q ( S ) :
34913502 assert p .integral (0 ) == p .integral ("x0" ) == \
34923503 mpoly ({(3 , 2 ): S (4 , 3 ), (2 , 0 ): S (3 , 2 ), (1 , 1 ): S (2 ), (1 , 0 ): S (1 )})
34933504 assert p .integral (1 ) == p .integral ("x1" ) == \
@@ -4908,7 +4919,7 @@ def test_python_threads():
49084919 from threading import Thread
49094920
49104921 iterations = 10 ** 5
4911- threads = 3 + 1
4922+ nthreads = 3 + 1
49124923 size = 3
49134924 M = flint .fmpz_mat ([[0 ]* size for _ in range (size )])
49144925
@@ -4927,7 +4938,7 @@ def get_dets():
49274938 for _ in range (iterations ):
49284939 M .det ()
49294940
4930- threads = [Thread (target = set_values ) for _ in range (threads - 1 )]
4941+ threads = [Thread (target = set_values ) for _ in range (nthreads - 1 )]
49314942 threads .append (Thread (target = get_dets ))
49324943
49334944 for t in threads :
0 commit comments