Skip to content

Commit 34f1b56

Browse files
committed
typ: fix up some type annotations for mypy
1 parent 1a429b0 commit 34f1b56

File tree

9 files changed

+90
-73
lines changed

9 files changed

+90
-73
lines changed

src/flint/flint_base/flint_base.pyi

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ Tctx = TypeVar('Tctx', bound=flint_mpoly_context)
1717

1818
Sctx = TypeVar('Sctx', bound=flint_mpoly_context)
1919

20+
_str = str
21+
2022

2123
class flint_elem:
2224
pass
@@ -50,8 +52,8 @@ class flint_mpoly(flint_elem, Generic[Tctx, Telem, Telem_coerce]):
5052
ctx: Tctx | None = None
5153
) -> None: ...
5254

53-
def str(self) -> str: ...
54-
def repr(self) -> str: ...
55+
def str(self) -> _str: ...
56+
def repr(self) -> _str: ...
5557

5658
def context(self) -> Tctx: ...
5759

@@ -108,8 +110,8 @@ class flint_mpoly(flint_elem, Generic[Tctx, Telem, Telem_coerce]):
108110

109111
def sqrt(self) -> Self: ...
110112

111-
def resultant(self, other: Self, var: str | int) -> Self: ...
112-
def discriminant(self, var: str | int) -> Self: ...
113+
def resultant(self, other: Self, var: _str | int) -> Self: ...
114+
def discriminant(self, var: _str | int) -> Self: ...
113115

114116
def deflation_index(self) -> tuple[list[int], list[int]]: ...
115117
def deflation(self) -> tuple[Self, list[int]]: ...
@@ -118,18 +120,18 @@ class flint_mpoly(flint_elem, Generic[Tctx, Telem, Telem_coerce]):
118120
def inflate(self, N: list[int]) -> Self: ...
119121
def deflate(self, N: list[int]) -> Self: ...
120122

121-
def subs(self, mapping: dict[str | int, Telem | Telem_coerce | int]) -> Self: ...
123+
def subs(self, mapping: dict[_str | int, Telem | Telem_coerce | int]) -> Self: ...
122124
def compose(self, *args: Self, ctx: Tctx | None = None) -> Self: ...
123125
def __call__(self, *args: Telem | Telem_coerce) -> Telem: ...
124126

125-
def derivative(self, var: str | int) -> Self: ...
127+
def derivative(self, var: _str | int) -> Self: ...
126128

127-
def unused_gens(self) -> tuple[str, ...]: ...
129+
def unused_gens(self) -> tuple[_str, ...]: ...
128130

129131
def project_to_context(
130132
self,
131133
other_ctx: Tctx,
132-
mapping: dict[str | int, str | int] | None = None
134+
mapping: dict[_str | int, _str | int] | None = None
133135
) -> Self: ...
134136

135137

src/flint/test/test_all.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
from __future__ import annotations
2-
from typing import Any, Callable, TypeVar, Iterable, Protocol
2+
from typing import Any, Callable, TypeVar, Iterable, Protocol, TYPE_CHECKING
33

44
import math
55
import operator
66
import pickle
77
import platform
88
import random
9-
from functools import wraps
109

1110
import flint
1211
import flint.flint_base.flint_base as flint_base
@@ -25,9 +24,12 @@ def raises(f, exception):
2524
return False
2625

2726

28-
Tscalar = TypeVar('Tscalar', bound=flint_base.flint_scalar)
29-
Tmpoly = TypeVar('Tmpoly', bound=flint_base.flint_mpoly)
30-
Tmpolyctx_co = TypeVar('Tmpolyctx_co', bound=flint_base.flint_mpoly_context, covariant=True)
27+
if TYPE_CHECKING:
28+
from typing import TypeIs
29+
Tscalar = TypeVar('Tscalar', bound=flint_base.flint_scalar)
30+
Tscalar_co = TypeVar('Tscalar_co', bound=flint_base.flint_scalar, covariant=True)
31+
Tmpoly = TypeVar('Tmpoly', bound=flint_base.flint_mpoly)
32+
Tmpolyctx_co = TypeVar('Tmpolyctx_co', bound=flint_base.flint_mpoly_context, covariant=True)
3133

3234

3335
_default_ctx_string = """\
@@ -2943,6 +2945,15 @@ def __call__(self,
29432945
]
29442946

29452947

2948+
class _Q(Protocol[Tscalar_co]):
2949+
def __call__(self, a: int, b: int | None = None, /) -> Tscalar_co:
2950+
...
2951+
2952+
2953+
def _is_Q(typ: object) -> TypeIs[_Q]:
2954+
return typ is flint.fmpq
2955+
2956+
29462957
def _for_all_mpolys(test: Callable[[_MPolyTestCase], None]) -> None:
29472958
"""Test all mpoly types with the given test function."""
29482959
# Spell it out like this so that a type checker can understand the types
@@ -3001,7 +3012,7 @@ def wrapper():
30013012

30023013
@all_mpolys
30033014
def test_mpolys_constructor(args: _MPolyTestCase[Tmpoly, Tscalar]) -> None:
3004-
P, get_context, S, is_field, characteristic = args
3015+
P, get_context, S, _, _ = args
30053016

30063017
ctx = get_context(("x", 2))
30073018

@@ -3487,7 +3498,7 @@ def quick_poly():
34873498
assert raises(lambda: p.derivative(None), TypeError) # type: ignore
34883499

34893500
if isinstance(p, (flint.fmpz_mpoly, flint.fmpq_mpoly)):
3490-
if isinstance(p, flint.fmpq_mpoly):
3501+
if isinstance(p, flint.fmpq_mpoly) and _is_Q(S):
34913502
assert p.integral(0) == p.integral("x0") == \
34923503
mpoly({(3, 2): S(4, 3), (2, 0): S(3, 2), (1, 1): S(2), (1, 0): S(1)})
34933504
assert p.integral(1) == p.integral("x1") == \
@@ -4908,7 +4919,7 @@ def test_python_threads():
49084919
from threading import Thread
49094920

49104921
iterations = 10**5
4911-
threads = 3 + 1
4922+
nthreads = 3 + 1
49124923
size = 3
49134924
M = flint.fmpz_mat([[0]*size for _ in range(size)])
49144925

@@ -4927,7 +4938,7 @@ def get_dets():
49274938
for _ in range(iterations):
49284939
M.det()
49294940

4930-
threads = [Thread(target=set_values) for _ in range(threads-1)]
4941+
threads = [Thread(target=set_values) for _ in range(nthreads-1)]
49314942
threads.append(Thread(target=get_dets))
49324943

49334944
for t in threads:

src/flint/types/fmpq.pyi

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ from typing import overload
33
from .fmpz import fmpz
44

55

6+
_str = str
67
ifmpz = int | fmpz
78
ifmpq = int | fmpz | fmpq
89

@@ -11,12 +12,10 @@ class fmpq(flint_scalar):
1112
@overload
1213
def __init__(self): ...
1314
@overload
14-
def __init__(self, arg: ifmpq | str, /): ...
15+
def __init__(self, arg: ifmpq | _str, /): ...
1516
@overload
1617
def __init__(self, num: ifmpz, den: ifmpz, /): ...
1718

18-
def __init__(self, arg1: ifmpq | str = 0, arg2: ifmpz = 1, /): ...
19-
2019
@property
2120
def p(self) -> fmpz: ...
2221
@property
@@ -34,11 +33,11 @@ class fmpq(flint_scalar):
3433

3534
def next(self, signed: bool = True, minimal: bool = True) -> fmpq: ...
3635

37-
def str(self, base: int = 10, condense: int = 0) -> str: ...
38-
def repr(self) -> str: ...
36+
def str(self, base: int = 10, condense: int = 0) -> _str: ...
37+
def repr(self) -> _str: ...
3938

40-
def __str__(self) -> str: ...
41-
def __repr__(self) -> str: ...
39+
def __str__(self) -> _str: ...
40+
def __repr__(self) -> _str: ...
4241

4342
def __int__(self) -> int: ...
4443

@@ -54,8 +53,6 @@ class fmpq(flint_scalar):
5453
@overload
5554
def __round__(self) -> fmpz: ...
5655

57-
def __round__(self, ndigits: int | None = None) -> fmpq | fmpz: ...
58-
5956
def __lt__(self, other: ifmpq, /) -> bool: ...
6057
def __le__(self, other: ifmpq, /) -> bool: ...
6158
def __gt__(self, other: ifmpq, /) -> bool: ...

src/flint/types/fmpq_mpoly.pyi

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,16 @@ from .fmpq import fmpq
55
from .fmpz_mpoly import fmpz_mpoly
66

77

8+
_str = str
89
ifmpz = int | fmpz
910
ifmpq = int | fmpz | fmpq
1011

1112

1213
class fmpq_mpoly_ctx(flint_mpoly_context[fmpq_mpoly, fmpq, ifmpq]):
1314
@classmethod
1415
def get(cls,
15-
names: str | Iterable[str | tuple[str, int]] | tuple[str, int],
16-
ordering: Ordering | str = Ordering.lex
16+
names: _str | Iterable[_str | tuple[_str, int]] | tuple[_str, int],
17+
ordering: Ordering | _str = Ordering.lex
1718
) -> fmpq_mpoly_ctx:
1819
...
1920

@@ -27,12 +28,12 @@ class fmpq_mpoly_ctx(flint_mpoly_context[fmpq_mpoly, fmpq, ifmpq]):
2728

2829
class fmpq_mpoly(flint_mpoly[fmpq_mpoly_ctx, fmpq, ifmpq]):
2930
def __init__(self,
30-
val: fmpq_mpoly | fmpz_mpoly | ifmpq | dict[tuple[int, ...], ifmpq] | str = 0,
31+
val: fmpq_mpoly | fmpz_mpoly | ifmpq | dict[tuple[int, ...], ifmpq] | _str = 0,
3132
ctx: fmpq_mpoly_ctx | None = None
3233
) -> None: ...
3334

34-
def str(self) -> str: ...
35-
def repr(self) -> str: ...
35+
def str(self) -> _str: ...
36+
def repr(self) -> _str: ...
3637

3738
def context(self) -> fmpq_mpoly_ctx: ...
3839

@@ -52,7 +53,7 @@ class fmpq_mpoly(flint_mpoly[fmpq_mpoly_ctx, fmpq, ifmpq]):
5253
def __getitem__(self, index: tuple[int, ...]) -> fmpq: ...
5354
def __setitem__(self, index: tuple[int, ...], coeff: ifmpq) -> None: ...
5455

55-
def subs(self, mapping: dict[str | int, ifmpq]) -> fmpq_mpoly: ...
56+
def subs(self, mapping: dict[_str | int, ifmpq]) -> fmpq_mpoly: ...
5657
def compose(self, *args: fmpq_mpoly, ctx: fmpq_mpoly_ctx | None = None) -> fmpq_mpoly: ...
5758

5859
def __call__(self, *args: ifmpq) -> fmpq: ...
@@ -89,17 +90,17 @@ class fmpq_mpoly(flint_mpoly[fmpq_mpoly_ctx, fmpq, ifmpq]):
8990

9091
def sqrt(self, assume_perfect_square: bool = False) -> fmpq_mpoly: ...
9192

92-
def resultant(self, other: fmpq_mpoly, var: str | int) -> fmpq_mpoly: ...
93-
def discriminant(self, var: str | int) -> fmpq_mpoly: ...
93+
def resultant(self, other: fmpq_mpoly, var: _str | int) -> fmpq_mpoly: ...
94+
def discriminant(self, var: _str | int) -> fmpq_mpoly: ...
9495

9596
def deflation(self) -> tuple[fmpq_mpoly, list[int]]: ...
9697
def inflate(self, N: list[int]) -> fmpq_mpoly: ...
9798
def deflate(self, N: list[int]) -> fmpq_mpoly: ...
9899
def deflation_monom(self) -> tuple[fmpq_mpoly, list[int], fmpq_mpoly]: ...
99100
def deflation_index(self) -> tuple[list[int], list[int]]: ...
100101

101-
def derivative(self, var: str | int) -> fmpq_mpoly: ...
102-
def integral(self, var: str | int) -> fmpq_mpoly: ...
102+
def derivative(self, var: _str | int) -> fmpq_mpoly: ...
103+
def integral(self, var: _str | int) -> fmpq_mpoly: ...
103104

104105

105106
class fmpq_mpoly_vec:

src/flint/types/fmpz.pyi

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,10 @@
44
#
55
from ..flint_base.flint_base import flint_scalar
66

7+
_str = str
8+
79
class fmpz(flint_scalar):
8-
def __init__(self, arg: int | fmpz | str = 0, /): ...
10+
def __init__(self, arg: int | fmpz | _str = 0, /): ...
911

1012
@property
1113
def numerator(self) -> fmpz: ...
@@ -15,11 +17,11 @@ class fmpz(flint_scalar):
1517
def bit_length(self) -> int: ...
1618
def height_bits(self, signed: bool = False) -> int: ...
1719

18-
def str(self, base: int = 10, condense: int = 0) -> str: ...
19-
def repr(self) -> str: ...
20+
def str(self, base: int = 10, condense: int = 0) -> _str: ...
21+
def repr(self) -> _str: ...
2022

21-
def __str__(self) -> str: ...
22-
def __repr__(self) -> str: ...
23+
def __str__(self) -> _str: ...
24+
def __repr__(self) -> _str: ...
2325

2426
def __int__(self) -> int: ...
2527
def __index__(self) -> int: ...

src/flint/types/fmpz_mod_mpoly.pyi

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ from .fmpz_mod import fmpz_mod
55
from .fmpz_mpoly import fmpz_mpoly
66

77

8+
_str = str
89
ifmpz = int | fmpz
910
ifmpz_mod = int | fmpz | fmpz_mod
1011

@@ -13,8 +14,8 @@ class fmpz_mod_mpoly_ctx(flint_mod_mpoly_context[fmpz_mod_mpoly, fmpz_mod, ifmpz
1314

1415
@classmethod
1516
def get(cls,
16-
names: str | Iterable[str | tuple[str, int]] | tuple[str, int],
17-
ordering: Ordering | str = Ordering.lex,
17+
names: _str | Iterable[_str | tuple[_str, int]] | tuple[_str, int],
18+
ordering: Ordering | _str = Ordering.lex,
1819
*,
1920
modulus: ifmpz,
2021
) -> fmpz_mod_mpoly_ctx:
@@ -30,12 +31,12 @@ class fmpz_mod_mpoly_ctx(flint_mod_mpoly_context[fmpz_mod_mpoly, fmpz_mod, ifmpz
3031

3132
class fmpz_mod_mpoly(flint_mpoly[fmpz_mod_mpoly_ctx, fmpz_mod, ifmpz_mod]):
3233
def __init__(self,
33-
val: fmpz_mod_mpoly | fmpz_mpoly | ifmpz_mod | dict[tuple[int, ...], ifmpz_mod] | str = 0,
34+
val: fmpz_mod_mpoly | fmpz_mpoly | ifmpz_mod | dict[tuple[int, ...], ifmpz_mod] | _str = 0,
3435
ctx: fmpz_mod_mpoly_ctx | None = None
3536
) -> None: ...
3637

37-
def str(self) -> str: ...
38-
def repr(self) -> str: ...
38+
def str(self) -> _str: ...
39+
def repr(self) -> _str: ...
3940

4041
def context(self) -> fmpz_mod_mpoly_ctx: ...
4142

@@ -55,7 +56,7 @@ class fmpz_mod_mpoly(flint_mpoly[fmpz_mod_mpoly_ctx, fmpz_mod, ifmpz_mod]):
5556
def __getitem__(self, index: tuple[int, ...]) -> fmpz_mod: ...
5657
def __setitem__(self, index: tuple[int, ...], coeff: ifmpz_mod) -> None: ...
5758

58-
def subs(self, mapping: dict[str | int, ifmpz_mod]) -> fmpz_mod_mpoly: ...
59+
def subs(self, mapping: dict[_str | int, ifmpz_mod]) -> fmpz_mod_mpoly: ...
5960
def compose(self, *args: fmpz_mod_mpoly, ctx: fmpz_mod_mpoly_ctx | None = None) -> fmpz_mod_mpoly: ...
6061

6162
def __call__(self, *args: ifmpz_mod) -> fmpz_mod: ...
@@ -93,16 +94,16 @@ class fmpz_mod_mpoly(flint_mpoly[fmpz_mod_mpoly_ctx, fmpz_mod, ifmpz_mod]):
9394

9495
def sqrt(self, assume_perfect_square: bool = False) -> fmpz_mod_mpoly: ...
9596

96-
def resultant(self, other: fmpz_mod_mpoly, var: str | int) -> fmpz_mod_mpoly: ...
97-
def discriminant(self, var: str | int) -> fmpz_mod_mpoly: ...
97+
def resultant(self, other: fmpz_mod_mpoly, var: _str | int) -> fmpz_mod_mpoly: ...
98+
def discriminant(self, var: _str | int) -> fmpz_mod_mpoly: ...
9899

99100
def deflation(self) -> tuple[fmpz_mod_mpoly, list[int]]: ...
100101
def inflate(self, N: list[int]) -> fmpz_mod_mpoly: ...
101102
def deflate(self, N: list[int]) -> fmpz_mod_mpoly: ...
102103
def deflation_monom(self) -> tuple[fmpz_mod_mpoly, list[int], fmpz_mod_mpoly]: ...
103104
def deflation_index(self) -> tuple[list[int], list[int]]: ...
104105

105-
def derivative(self, var: str | int) -> fmpz_mod_mpoly: ...
106+
def derivative(self, var: _str | int) -> fmpz_mod_mpoly: ...
106107

107108

108109
class fmpz_mod_mpoly_vec:

0 commit comments

Comments
 (0)