Skip to content

Commit e9aa08f

Browse files
committed
typ: Add flint.typing module
1 parent ee6e8bc commit e9aa08f

File tree

9 files changed

+316
-48
lines changed

9 files changed

+316
-48
lines changed

src/flint/flint_base/flint_base.pyi

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,13 @@ class flint_scalar(flint_elem):
3636
def is_zero(self) -> bool: ...
3737
def __pos__(self) -> Self: ...
3838
def __neg__(self) -> Self: ...
39-
def __add__(self, other: Self | int, /) -> Self: ...
39+
def __add__(self, other: int, /) -> Self: ...
4040
def __radd__(self, other: int, /) -> Self: ...
41-
def __sub__(self, other: Self | int, /) -> Self: ...
41+
def __sub__(self, other: int, /) -> Self: ...
4242
def __rsub__(self, other: int, /) -> Self: ...
43-
def __mul__(self, other: Self | int, /) -> Self: ...
43+
def __mul__(self, other: int, /) -> Self: ...
4444
def __rmul__(self, other: int, /) -> Self: ...
45-
def __truediv__(self, other: Self | int, /) -> Self: ...
45+
def __truediv__(self, other: int, /) -> Self: ...
4646
def __rtruediv__(self, other: int, /) -> Self: ...
4747
def __pow__(self, other: int, /) -> Self: ...
4848
def __rpow__(self, other: int, /) -> Self: ...
@@ -90,15 +90,6 @@ class flint_poly(flint_elem, Generic[Telem]):
9090
def complex_roots(self) -> list[Any]: ...
9191
def derivative(self) -> Self: ...
9292

93-
class _flint_poly_exact(flint_poly[Telem]):
94-
def sqrt(self) -> Self: ...
95-
def gcd(self, other: Self | Telem, /) -> Self: ...
96-
def xgcd(self, other: Self | Telem, /) -> tuple[Self, Self, Self]: ...
97-
def factor(self) -> tuple[Telem, list[tuple[Self, int]]]: ...
98-
def factor_squarefree(self) -> tuple[Telem, list[tuple[Self, int]]]: ...
99-
def resultant(self, other: Self | Telem, /) -> Telem: ...
100-
def deflation(self) -> tuple[Self, int]: ...
101-
10293
class Ordering(enum.Enum):
10394
lex = "lex"
10495
deglex = "deglex"

src/flint/meson.build

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ thisdir = 'flint'
22

33
pyfiles = [
44
'__init__.py',
5+
'typing.py',
56
]
67

78
exts = [

src/flint/test/test_all.py

Lines changed: 17 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import random
99

1010
import flint
11+
import flint.typing as typ
1112
import flint.flint_base.flint_base as flint_base
1213
from flint.utils.flint_exceptions import DomainError, IncompatibleContextError
1314

@@ -26,13 +27,10 @@ def raises(f, exception) -> bool:
2627

2728
if TYPE_CHECKING:
2829
from typing import TypeIs
29-
from flint.flint_base.flint_base import _flint_poly_exact
3030

3131

3232
Tscalar = TypeVar('Tscalar', bound=flint_base.flint_scalar)
3333
Tscalar_co = TypeVar('Tscalar_co', bound=flint_base.flint_scalar, covariant=True)
34-
Tscalar_contra = TypeVar('Tscalar_contra', bound=flint_base.flint_scalar, contravariant=True)
35-
Tpoly = TypeVar("Tpoly", bound='_flint_poly_exact')
3634
Tmpoly = TypeVar('Tmpoly', bound=flint_base.flint_mpoly)
3735
Tmpolyctx_co = TypeVar('Tmpolyctx_co', bound=flint_base.flint_mpoly_context, covariant=True)
3836

@@ -2621,32 +2619,26 @@ def _all_polys() -> list[tuple[Any, Any, bool, flint.fmpz]]:
26212619
]
26222620

26232621

2624-
class _TPoly(Protocol[Tpoly, Tscalar_contra]):
2625-
def __call__(
2626-
self, x: Sequence[Tscalar_contra | int] | Tpoly | Tscalar_contra | int, /
2627-
) -> Tpoly: ...
2628-
2629-
2630-
class _Telem(Protocol[Tscalar]):
2631-
def __call__(self, x: int | Tscalar, /) -> Tscalar: ...
2632-
2633-
2634-
_PolyTestCase = tuple[_TPoly[Tpoly, Tscalar], _Telem[Tscalar], bool, flint.fmpz]
2622+
Tpoly = TypeVar("Tpoly", bound=typ.epoly_p)
2623+
Tc = TypeVar("Tc", bound=flint_base.flint_scalar)
2624+
TS = Callable[[Tc | int], Tc]
2625+
TP = Callable[[Tpoly | Sequence[Tc | int] | Tc | int], Tpoly]
2626+
_PolyTestCase = tuple[TP[Tpoly,Tc], TS[Tc], bool, flint.fmpz]
26352627

26362628

26372629
def _for_all_polys(test: Callable[[_PolyTestCase], None]) -> None:
26382630
"""Test all mpoly types with the given test function."""
26392631
# Spell it out like this so that a type checker can understand the types
26402632
# in the generics for each call of test().
26412633

2642-
fmpz: _Telem[flint.fmpz] = flint.fmpz
2643-
fmpq: _Telem[flint.fmpq] = flint.fmpq
2644-
fmpz_poly: _TPoly[flint.fmpz_poly, flint.fmpz] = flint.fmpz_poly
2645-
fmpq_poly: _TPoly[flint.fmpq_poly, flint.fmpq] = flint.fmpq_poly
2634+
fmpz: TS[flint.fmpz] = flint.fmpz
2635+
fmpq: TS[flint.fmpq] = flint.fmpq
2636+
fmpz_poly: TP[flint.fmpz_poly, flint.fmpz] = flint.fmpz_poly
2637+
fmpq_poly: TP[flint.fmpq_poly, flint.fmpq] = flint.fmpq_poly
26462638

26472639
def nmod_poly(
26482640
p: int,
2649-
) -> tuple[_TPoly[flint.nmod_poly, flint.nmod], _Telem[flint.nmod]]:
2641+
) -> tuple[TP[flint.nmod_poly, flint.nmod], TS[flint.nmod]]:
26502642
"""Make nmod poly and scalar constructors for modulus p."""
26512643

26522644
def poly(
@@ -2661,7 +2653,7 @@ def elem(x: int | flint.nmod = 0, /) -> flint.nmod:
26612653

26622654
def fmpz_mod_poly(
26632655
p: int,
2664-
) -> tuple[_TPoly[flint.fmpz_mod_poly, flint.fmpz_mod], _Telem[flint.fmpz_mod]]:
2656+
) -> tuple[TP[flint.fmpz_mod_poly, flint.fmpz_mod], TS[flint.fmpz_mod]]:
26652657
"""Make fmpz_mod poly and scalar constructors for modulus p."""
26662658
ectx = flint.fmpz_mod_ctx(p)
26672659
pctx = flint.fmpz_mod_poly_ctx(ectx)
@@ -2683,7 +2675,7 @@ def elem(x: int | flint.fmpz_mod = 0, /) -> flint.fmpz_mod:
26832675
def fq_default_poly(
26842676
p: int, k: int | None = None
26852677
) -> tuple[
2686-
_TPoly[flint.fq_default_poly, flint.fq_default], _Telem[flint.fq_default]
2678+
TP[flint.fq_default_poly, flint.fq_default], TS[flint.fq_default]
26872679
]:
26882680
"""Make fq_default poly and scalar constructors for field p^k."""
26892681
if k is None:
@@ -2740,7 +2732,7 @@ def wrapper():
27402732

27412733

27422734
@all_polys
2743-
def test_polys(args: _PolyTestCase[Tpoly, Tscalar]) -> None:
2735+
def test_polys(args: _PolyTestCase[typ.epoly_p[Tc], Tc]) -> None:
27442736
# To test type annotations, uncomment:
27452737
# P: type[flint.fmpq_poly]
27462738
# S: type[flint.fmpq]
@@ -2872,7 +2864,7 @@ def setbad(obj, i, val):
28722864

28732865
assert P([1, 2, 3]) + P([4, 5, 6]) == P([5, 7, 9])
28742866

2875-
for T in [int, S, flint.fmpz]:
2867+
for T in (int, S, flint.fmpz):
28762868
assert P([1, 2, 3]) + T(1) == P([2, 2, 3])
28772869
assert T(1) + P([1, 2, 3]) == P([2, 2, 3])
28782870

@@ -2881,7 +2873,7 @@ def setbad(obj, i, val):
28812873

28822874
assert P([1, 2, 3]) - P([4, 5, 6]) == P([-3, -3, -3])
28832875

2884-
for T in [int, S, flint.fmpz]:
2876+
for T in (int, S, flint.fmpz):
28852877
assert P([1, 2, 3]) - T(1) == P([0, 2, 3])
28862878
assert T(1) - P([1, 2, 3]) == P([0, -2, -3])
28872879

@@ -2890,7 +2882,7 @@ def setbad(obj, i, val):
28902882

28912883
assert P([1, 2, 3]) * P([4, 5, 6]) == P([4, 13, 28, 27, 18])
28922884

2893-
for T in [int, S, flint.fmpz]:
2885+
for T in (int, S, flint.fmpz):
28942886
assert P([1, 2, 3]) * T(2) == P([2, 4, 6])
28952887
assert T(2) * P([1, 2, 3]) == P([2, 4, 6])
28962888

src/flint/types/fmpq_poly.pyi

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from typing import overload, Any, Sequence
2-
from flint.flint_base.flint_base import _flint_poly_exact
2+
from flint.flint_base.flint_base import flint_poly
33
from flint.types.fmpz import fmpz, ifmpz
44
from flint.types.fmpq import fmpq, ifmpq
55
from flint.types.fmpz_poly import fmpz_poly, ifmpz_poly
@@ -8,7 +8,7 @@ from flint.types.fmpz_poly import fmpz_poly, ifmpz_poly
88
ifmpq_poly = fmpq_poly | ifmpq | ifmpz_poly
99

1010

11-
class fmpq_poly(_flint_poly_exact[fmpq]):
11+
class fmpq_poly(flint_poly[fmpq]):
1212
"""
1313
The *fmpq_poly* type represents dense univariate polynomials
1414
over the rational numbers. For efficiency reasons, an *fmpq_poly* is

src/flint/types/fmpz_mod_poly.pyi

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from typing import Literal, Sequence, overload
22

3-
from flint.flint_base.flint_base import _flint_poly_exact
3+
from flint.flint_base.flint_base import flint_poly
44
from flint.types.fmpz import fmpz
55
from flint.types.fmpz_poly import fmpz_poly
66
from flint.types.fmpz_mod import fmpz_mod, fmpz_mod_ctx, ifmpz, ifmpz_mod
@@ -32,7 +32,7 @@ class fmpz_mod_poly_ctx:
3232
def __call__(self, val: ifmpz_mod_poly | list[ifmpz_mod]) -> fmpz_mod_poly: ...
3333
def minpoly(self, vals: Sequence[ifmpz_mod]) -> fmpz_mod_poly: ...
3434

35-
class fmpz_mod_poly(_flint_poly_exact[fmpz_mod]):
35+
class fmpz_mod_poly(flint_poly[fmpz_mod]):
3636
"""
3737
The *fmpz_mod_poly* type represents univariate polynomials
3838
over integer modulo an arbitrary-size modulus.

src/flint/types/fmpz_poly.pyi

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
from typing import overload, Any, Sequence
2-
from flint.flint_base.flint_base import _flint_poly_exact
2+
from flint.flint_base.flint_base import flint_poly
33
from flint.types.fmpz import fmpz, ifmpz
44
from flint.types.fmpq import fmpq
55
from flint.types.fmpq_poly import fmpq_poly
66

77
ifmpz_poly = fmpz_poly | ifmpz
88

9-
class fmpz_poly(_flint_poly_exact[fmpz]):
9+
class fmpz_poly(flint_poly[fmpz]):
1010
"""
1111
The *fmpz_poly* type represents dense univariate polynomials over
1212
the integers.

src/flint/types/fq_default_poly.pyi

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from typing import overload, Sequence
2-
from flint.flint_base.flint_base import _flint_poly_exact
2+
from flint.flint_base.flint_base import flint_poly
33
from .fmpz import fmpz, ifmpz
44
from .fmpz_mod import fmpz_mod
55
from .fmpz_poly import fmpz_poly, ifmpz_poly
@@ -44,7 +44,7 @@ class fq_default_poly_ctx:
4444
def __repr__(self) -> str: ...
4545
def __call__(self, val: ifq_default_poly) -> fq_default_poly: ...
4646

47-
class fq_default_poly(_flint_poly_exact[fq_default]):
47+
class fq_default_poly(flint_poly[fq_default]):
4848
"""
4949
The *fq_default_poly* type represents univariate polynomials
5050
over a finite field.

src/flint/types/nmod_poly.pyi

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
from typing import overload, Iterator, Sequence
2-
from flint.flint_base.flint_base import _flint_poly_exact
2+
from flint.flint_base.flint_base import flint_poly
33
from flint.types.nmod import inmod, nmod
44
from flint.types.fmpz_poly import fmpz_poly
55

66
inmod_poly = nmod_poly | fmpz_poly | inmod
77

8-
class nmod_poly(_flint_poly_exact[nmod]):
8+
class nmod_poly(flint_poly[nmod]):
99
"""Dense univariate polynomials over Z/nZ for word-size n."""
1010

1111
@overload

0 commit comments

Comments
 (0)