Skip to content

Commit 97b5a41

Browse files
committed
Use scalar_p as upper bound for poly_p
1 parent cf14567 commit 97b5a41

File tree

1 file changed

+85
-83
lines changed

1 file changed

+85
-83
lines changed

src/flint/typing.py

Lines changed: 85 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -52,14 +52,13 @@
5252
import flint as _flint
5353

5454

55-
Telem = TypeVar("Telem", bound=flint_scalar)
56-
Telem_co = TypeVar("Telem_co", bound=flint_scalar, covariant=True)
57-
Telem_coerce = TypeVar("Telem_coerce")
58-
Telem_coerce_co = TypeVar("Telem_coerce_co", covariant=True)
59-
Telem_coerce_contra = TypeVar("Telem_coerce_contra", contravariant=True)
60-
Tmpoly = TypeVar("Tmpoly", bound=flint_mpoly, covariant=True)
61-
Tctx = TypeVar("Tctx", bound=flint_mpoly_context)
62-
Sctx = TypeVar("Sctx", bound=flint_mpoly_context)
55+
_Telem = TypeVar("_Telem", bound=flint_scalar)
56+
_Telem_co = TypeVar("_Telem_co", bound=flint_scalar, covariant=True)
57+
_Telem_coerce = TypeVar("_Telem_coerce")
58+
_Telem_coerce_contra = TypeVar("_Telem_coerce_contra", contravariant=True)
59+
_Tmpoly = TypeVar("_Tmpoly", bound=flint_mpoly, covariant=True)
60+
_Tctx = TypeVar("_Tctx", bound=flint_mpoly_context)
61+
_Sctx = TypeVar("_Sctx", bound=flint_mpoly_context)
6362

6463
_str = str
6564

@@ -90,122 +89,125 @@ def __pow__(self, other: int, /) -> Self: ...
9089
def __rpow__(self, other: int, /) -> Self: ...
9190

9291

93-
class poly_p(elem_p, Protocol[Telem]):
92+
_Tscalar = TypeVar("_Tscalar", bound=scalar_p)
93+
94+
95+
class poly_p(elem_p, Protocol[_Tscalar]):
9496
"""FLINT univariate polynomial Protocol."""
9597
def str(
9698
self, ascending: bool = False, var: str = "x", *args: Any, **kwargs: Any
9799
) -> str: ...
98-
def __iter__(self) -> Iterator[Telem]: ...
99-
def __getitem__(self, index: int, /) -> Telem: ...
100-
def __setitem__(self, index: int, value: Telem | int, /) -> None: ...
100+
def __iter__(self) -> Iterator[_Tscalar]: ...
101+
def __getitem__(self, index: int, /) -> _Tscalar: ...
102+
def __setitem__(self, index: int, value: _Tscalar | int, /) -> None: ...
101103
def __len__(self) -> int: ...
102104
def length(self) -> int: ...
103105
def degree(self) -> int: ...
104-
def coeffs(self) -> list[Telem]: ...
106+
def coeffs(self) -> list[_Tscalar]: ...
105107
@overload
106-
def __call__(self, other: Telem | ifmpz, /) -> Telem: ...
108+
def __call__(self, other: _Tscalar | ifmpz, /) -> _Tscalar: ...
107109
@overload
108110
def __call__(self, other: Self, /) -> Self: ... # pyright: ignore[reportOverlappingOverload]
109111

110112
def __pos__(self) -> Self: ...
111113
def __neg__(self) -> Self: ...
112-
def __add__(self, other: Telem | ifmpz | Self, /) -> Self: ...
113-
def __radd__(self, other: Telem | ifmpz, /) -> Self: ...
114-
def __sub__(self, other: Telem | ifmpz | Self, /) -> Self: ...
115-
def __rsub__(self, other: Telem | ifmpz, /) -> Self: ...
116-
def __mul__(self, other: Telem | ifmpz | Self, /) -> Self: ...
117-
def __rmul__(self, other: Telem | ifmpz, /) -> Self: ...
118-
def __truediv__(self, other: Telem | ifmpz | Self, /) -> Self: ...
119-
def __rtruediv__(self, other: Telem | ifmpz, /) -> Self: ...
120-
def __floordiv__(self, other: Telem | ifmpz | Self, /) -> Self: ...
121-
def __rfloordiv__(self, other: Telem | ifmpz, /) -> Self: ...
122-
def __mod__(self, other: Telem | ifmpz | Self, /) -> Self: ...
123-
def __rmod__(self, other: Telem | ifmpz, /) -> Self: ...
124-
def __divmod__(self, other: Telem | ifmpz | Self, /) -> tuple[Self, Self]: ...
125-
def __rdivmod__(self, other: Telem | ifmpz, /) -> tuple[Self, Self]: ...
114+
def __add__(self, other: _Tscalar | ifmpz | Self, /) -> Self: ...
115+
def __radd__(self, other: _Tscalar | ifmpz, /) -> Self: ...
116+
def __sub__(self, other: _Tscalar | ifmpz | Self, /) -> Self: ...
117+
def __rsub__(self, other: _Tscalar | ifmpz, /) -> Self: ...
118+
def __mul__(self, other: _Tscalar | ifmpz | Self, /) -> Self: ...
119+
def __rmul__(self, other: _Tscalar | ifmpz, /) -> Self: ...
120+
def __truediv__(self, other: _Tscalar | ifmpz | Self, /) -> Self: ...
121+
def __rtruediv__(self, other: _Tscalar | ifmpz, /) -> Self: ...
122+
def __floordiv__(self, other: _Tscalar | ifmpz | Self, /) -> Self: ...
123+
def __rfloordiv__(self, other: _Tscalar | ifmpz, /) -> Self: ...
124+
def __mod__(self, other: _Tscalar | ifmpz | Self, /) -> Self: ...
125+
def __rmod__(self, other: _Tscalar | ifmpz, /) -> Self: ...
126+
def __divmod__(self, other: _Tscalar | ifmpz | Self, /) -> tuple[Self, Self]: ...
127+
def __rdivmod__(self, other: _Tscalar | ifmpz, /) -> tuple[Self, Self]: ...
126128
def __pow__(self, other: int, /) -> Self: ...
127129
def is_zero(self) -> bool: ...
128130
def is_one(self) -> bool: ...
129131
def is_constant(self) -> bool: ...
130132
def is_gen(self) -> bool: ...
131-
def roots(self) -> list[tuple[Telem, int]]: ...
133+
def roots(self) -> list[tuple[_Tscalar, int]]: ...
132134
# Should be list[arb]:
133135
def real_roots(self) -> list[Any]: ...
134136
# Should be list[acb]:
135137
def complex_roots(self) -> list[Any]: ...
136138
def derivative(self) -> Self: ...
137139

138140

139-
class epoly_p(poly_p[Telem], Protocol):
141+
class epoly_p(poly_p[_Tscalar], Protocol):
140142
"""FLINT exact univariate polynomial Protocol."""
141143
def sqrt(self) -> Self: ...
142-
def gcd(self, other: Self | Telem, /) -> Self: ...
143-
def factor(self) -> tuple[Telem, list[tuple[Self, int]]]: ...
144-
def factor_squarefree(self) -> tuple[Telem, list[tuple[Self, int]]]: ...
144+
def gcd(self, other: Self | _Tscalar, /) -> Self: ...
145+
def factor(self) -> tuple[_Tscalar, list[tuple[Self, int]]]: ...
146+
def factor_squarefree(self) -> tuple[_Tscalar, list[tuple[Self, int]]]: ...
145147
def deflation(self) -> tuple[Self, int]: ...
146148

147149

148-
class mpoly_p(elem_p, Protocol[Tctx, Telem, Telem_coerce]):
150+
class mpoly_p(elem_p, Protocol[_Tctx, _Telem, _Telem_coerce]):
149151
"""FLINT multivariate polynomial Protocol."""
150152
def __init__(
151153
self,
152154
val: Self
153-
| Telem
154-
| Telem_coerce
155+
| _Telem
156+
| _Telem_coerce
155157
| int
156-
| dict[tuple[int, ...], Telem | Telem_coerce | int]
158+
| dict[tuple[int, ...], _Telem | _Telem_coerce | int]
157159
| str = 0,
158-
ctx: Tctx | None = None,
160+
ctx: _Tctx | None = None,
159161
) -> None: ...
160162
def str(self) -> _str: ...
161163
def repr(self) -> _str: ...
162-
def context(self) -> Tctx: ...
164+
def context(self) -> _Tctx: ...
163165
def degrees(self) -> tuple[int, ...]: ...
164166
def total_degree(self) -> int: ...
165-
def leading_coefficient(self) -> Telem: ...
166-
def to_dict(self) -> dict[tuple[int, ...], Telem]: ...
167+
def leading_coefficient(self) -> _Telem: ...
168+
def to_dict(self) -> dict[tuple[int, ...], _Telem]: ...
167169
def is_one(self) -> bool: ...
168170
def is_zero(self) -> bool: ...
169171
def is_constant(self) -> bool: ...
170172
def __len__(self) -> int: ...
171-
def __getitem__(self, index: tuple[int, ...]) -> Telem: ...
173+
def __getitem__(self, index: tuple[int, ...]) -> _Telem: ...
172174
def __setitem__(
173-
self, index: tuple[int, ...], coeff: Telem | Telem_coerce | int
175+
self, index: tuple[int, ...], coeff: _Telem | _Telem_coerce | int
174176
) -> None: ...
175177
def __iter__(self) -> Iterable[tuple[int, ...]]: ...
176178
def __contains__(self, index: tuple[int, ...]) -> bool: ...
177-
def coefficient(self, i: int) -> Telem: ...
179+
def coefficient(self, i: int) -> _Telem: ...
178180
def monomial(self, i: int) -> tuple[int, ...]: ...
179-
def terms(self) -> Iterable[tuple[tuple[int, ...], Telem]]: ...
181+
def terms(self) -> Iterable[tuple[tuple[int, ...], _Telem]]: ...
180182
def monoms(self) -> list[tuple[int, ...]]: ...
181-
def coeffs(self) -> list[Telem]: ...
183+
def coeffs(self) -> list[_Telem]: ...
182184
def __pos__(self) -> Self: ...
183185
def __neg__(self) -> Self: ...
184-
def __add__(self, other: Self | Telem | Telem_coerce | int) -> Self: ...
185-
def __radd__(self, other: Telem | Telem_coerce | int) -> Self: ...
186-
def __sub__(self, other: Self | Telem | Telem_coerce | int) -> Self: ...
187-
def __rsub__(self, other: Telem | Telem_coerce | int) -> Self: ...
188-
def __mul__(self, other: Self | Telem | Telem_coerce | int) -> Self: ...
189-
def __rmul__(self, other: Telem | Telem_coerce | int) -> Self: ...
190-
def __truediv__(self, other: Self | Telem | Telem_coerce | int) -> Self: ...
191-
def __rtruediv__(self, other: Telem | Telem_coerce | int) -> Self: ...
192-
def __floordiv__(self, other: Self | Telem | Telem_coerce | int) -> Self: ...
193-
def __rfloordiv__(self, other: Telem | Telem_coerce | int) -> Self: ...
194-
def __mod__(self, other: Self | Telem | Telem_coerce | int) -> Self: ...
195-
def __rmod__(self, other: Telem | Telem_coerce | int) -> Self: ...
186+
def __add__(self, other: Self | _Telem | _Telem_coerce | int) -> Self: ...
187+
def __radd__(self, other: _Telem | _Telem_coerce | int) -> Self: ...
188+
def __sub__(self, other: Self | _Telem | _Telem_coerce | int) -> Self: ...
189+
def __rsub__(self, other: _Telem | _Telem_coerce | int) -> Self: ...
190+
def __mul__(self, other: Self | _Telem | _Telem_coerce | int) -> Self: ...
191+
def __rmul__(self, other: _Telem | _Telem_coerce | int) -> Self: ...
192+
def __truediv__(self, other: Self | _Telem | _Telem_coerce | int) -> Self: ...
193+
def __rtruediv__(self, other: _Telem | _Telem_coerce | int) -> Self: ...
194+
def __floordiv__(self, other: Self | _Telem | _Telem_coerce | int) -> Self: ...
195+
def __rfloordiv__(self, other: _Telem | _Telem_coerce | int) -> Self: ...
196+
def __mod__(self, other: Self | _Telem | _Telem_coerce | int) -> Self: ...
197+
def __rmod__(self, other: _Telem | _Telem_coerce | int) -> Self: ...
196198
def __divmod__(
197-
self, other: Self | Telem | Telem_coerce | int
199+
self, other: Self | _Telem | _Telem_coerce | int
198200
) -> tuple[Self, Self]: ...
199-
def __rdivmod__(self, other: Telem | Telem_coerce | int) -> tuple[Self, Self]: ...
200-
def __pow__(self, other: Telem | Telem_coerce | int) -> Self: ...
201-
def __rpow__(self, other: Telem | Telem_coerce | int) -> Self: ...
202-
def iadd(self, other: Telem | Telem_coerce | int) -> None: ...
203-
def isub(self, other: Telem | Telem_coerce | int) -> None: ...
204-
def imul(self, other: Telem | Telem_coerce | int) -> None: ...
201+
def __rdivmod__(self, other: _Telem | _Telem_coerce | int) -> tuple[Self, Self]: ...
202+
def __pow__(self, other: _Telem | _Telem_coerce | int) -> Self: ...
203+
def __rpow__(self, other: _Telem | _Telem_coerce | int) -> Self: ...
204+
def iadd(self, other: _Telem | _Telem_coerce | int) -> None: ...
205+
def isub(self, other: _Telem | _Telem_coerce | int) -> None: ...
206+
def imul(self, other: _Telem | _Telem_coerce | int) -> None: ...
205207
def gcd(self, other: Self) -> Self: ...
206208
def term_content(self) -> Self: ...
207-
def factor(self) -> tuple[Telem, Sequence[tuple[Self, int]]]: ...
208-
def factor_squarefree(self) -> tuple[Telem, Sequence[tuple[Self, int]]]: ...
209+
def factor(self) -> tuple[_Telem, Sequence[tuple[Self, int]]]: ...
210+
def factor_squarefree(self) -> tuple[_Telem, Sequence[tuple[Self, int]]]: ...
209211
def sqrt(self) -> Self: ...
210212
def resultant(self, other: Self, var: _str | int) -> Self: ...
211213
def discriminant(self, var: _str | int) -> Self: ...
@@ -214,32 +216,32 @@ def deflation(self) -> tuple[Self, list[int]]: ...
214216
def deflation_monom(self) -> tuple[Self, list[int], Self]: ...
215217
def inflate(self, N: list[int]) -> Self: ...
216218
def deflate(self, N: list[int]) -> Self: ...
217-
def subs(self, mapping: dict[_str | int, Telem | Telem_coerce | int]) -> Self: ...
218-
def compose(self, *args: Self, ctx: Tctx | None = None) -> Self: ...
219-
def __call__(self, *args: Telem | Telem_coerce) -> Telem: ...
219+
def subs(self, mapping: dict[_str | int, _Telem | _Telem_coerce | int]) -> Self: ...
220+
def compose(self, *args: Self, ctx: _Tctx | None = None) -> Self: ...
221+
def __call__(self, *args: _Telem | _Telem_coerce) -> _Telem: ...
220222
def derivative(self, var: _str | int) -> Self: ...
221223
def unused_gens(self) -> tuple[_str, ...]: ...
222224
def project_to_context(
223-
self, other_ctx: Tctx, mapping: dict[_str | int, _str | int] | None = None
225+
self, other_ctx: _Tctx, mapping: dict[_str | int, _str | int] | None = None
224226
) -> Self: ...
225227

226228

227229
class mpoly_context_p(
228-
elem_p, Protocol[Tmpoly, Telem_co, Telem_coerce_contra]
230+
elem_p, Protocol[_Tmpoly, _Telem_co, _Telem_coerce_contra]
229231
):
230232
"""FLINT multivariate polynomial context protocol."""
231233
def nvars(self) -> int: ...
232234
def ordering(self) -> Ordering: ...
233-
def gen(self, i: int, /) -> Tmpoly: ...
234-
def from_dict(self, d: Mapping[tuple[int, ...], Telem_coerce_contra], /) -> Tmpoly: ...
235-
def constant(self, z: Telem_coerce_contra, /) -> Tmpoly: ...
235+
def gen(self, i: int, /) -> _Tmpoly: ...
236+
def from_dict(self, d: Mapping[tuple[int, ...], _Telem_coerce_contra], /) -> _Tmpoly: ...
237+
def constant(self, z: _Telem_coerce_contra, /) -> _Tmpoly: ...
236238
def name(self, i: int, /) -> str: ...
237239
def names(self) -> tuple[str]: ...
238-
def gens(self) -> tuple[Tmpoly, ...]: ...
240+
def gens(self) -> tuple[_Tmpoly, ...]: ...
239241
def variable_to_index(self, var: str, /) -> int: ...
240242
def term(
241-
self, coeff: Telem_coerce_contra | None = None, exp_vec: Iterable[int] | None = None
242-
) -> Tmpoly: ...
243+
self, coeff: _Telem_coerce_contra | None = None, exp_vec: Iterable[int] | None = None
244+
) -> _Tmpoly: ...
243245
def drop_gens(self, gens: Iterable[str | int], /) -> Self: ...
244246
def append_gens(self, gens: Iterable[str | int], /) -> Self: ...
245247
def infer_generator_mapping(
@@ -248,17 +250,17 @@ def infer_generator_mapping(
248250
@classmethod
249251
def from_context(
250252
cls,
251-
ctx: Sctx,
253+
ctx: _Sctx,
252254
names: str | Iterable[str | tuple[str, int]] | tuple[str, int] | None = None,
253255
ordering: Ordering | str = Ordering.lex,
254-
) -> Sctx: ...
256+
) -> _Sctx: ...
255257

256258

257-
class series_p(elem_p, Protocol[Telem]):
259+
class series_p(elem_p, Protocol[_Telem]):
258260
"""FLINT univariate power series."""
259261

260-
def __iter__(self) -> Iterator[Telem]: ...
261-
def coeffs(self) -> list[Telem]: ...
262+
def __iter__(self) -> Iterator[_Telem]: ...
263+
def coeffs(self) -> list[_Telem]: ...
262264

263265

264266
if TYPE_CHECKING:

0 commit comments

Comments
 (0)