Skip to content

Commit 846d3d8

Browse files
committed
Change mpoly context constructors
1 parent f7bfbe9 commit 846d3d8

File tree

7 files changed

+108
-117
lines changed

7 files changed

+108
-117
lines changed

src/flint/flint_base/flint_base.pxd

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,5 @@ cdef class flint_mat(flint_elem):
5757
cdef class flint_series(flint_elem):
5858
pass
5959

60-
cpdef enum Ordering:
61-
lex, deglex, degrevlex
62-
63-
cdef ordering_t ordering_py_to_c(ordering: Ordering)
60+
cdef ordering_t ordering_py_to_c(ordering)
6461
cdef ordering_c_to_py(ordering_t ordering)

src/flint/flint_base/flint_base.pyx

Lines changed: 81 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,16 @@ from flint.flintlib.types.flint cimport (
77
from flint.utils.flint_exceptions import DomainError
88
from flint.flintlib.types.mpoly cimport ordering_t
99
from flint.flint_base.flint_context cimport thectx
10-
from flint.flint_base.flint_base cimport Ordering
1110
from flint.utils.typecheck cimport typecheck
1211
cimport libc.stdlib
1312

1413
from typing import Optional
14+
from collections.abc import Iterable
1515
from flint.utils.flint_exceptions import IncompatibleContextError
1616

1717
from flint.types.fmpz cimport fmpz, any_as_fmpz
1818

19+
import enum
1920

2021
FLINT_BITS = _FLINT_BITS
2122
FLINT_VERSION = _FLINT_VERSION.decode("ascii")
@@ -265,21 +266,42 @@ cdef class flint_poly(flint_elem):
265266
raise NotImplementedError("Complex roots are not supported for this polynomial")
266267

267268

269+
class Ordering(enum.StrEnum):
270+
lex = "lex"
271+
deglex = "deglex"
272+
degrevlex = "degrevlex"
273+
274+
275+
cdef ordering_t ordering_py_to_c(ordering):
276+
if ordering == Ordering.lex:
277+
return ordering_t.ORD_LEX
278+
elif ordering == Ordering.deglex:
279+
return ordering_t.ORD_DEGLEX
280+
elif ordering == Ordering.degrevlex:
281+
return ordering_t.ORD_DEGREVLEX
282+
283+
cdef ordering_c_to_py(ordering_t ordering):
284+
if ordering == ordering_t.ORD_LEX:
285+
return Ordering.lex
286+
elif ordering == ordering_t.ORD_DEGLEX:
287+
return Ordering.deglex
288+
elif ordering == ordering_t.ORD_DEGREVLEX:
289+
return Ordering.degrevlex
290+
else:
291+
raise ValueError("unimplemented term order %d" % ordering)
292+
293+
268294
cdef class flint_mpoly_context(flint_elem):
269295
"""
270296
Base class for multivariate ring contexts
271297
"""
272298

273299
_ctx_cache = None
274300

275-
def __init__(self, int nvars, names):
276-
if nvars < 0:
277-
raise ValueError("cannot have a negative amount of variables")
278-
elif len(names) != nvars:
279-
raise ValueError("number of variables must match number of variable names")
301+
def __init__(self, names: Iterable[str]):
280302
self.py_names = tuple(name.encode("ascii") if not isinstance(name, bytes) else name for name in names)
281-
self.c_names = <const char**> libc.stdlib.malloc(nvars * sizeof(const char *))
282-
for i in range(nvars):
303+
self.c_names = <const char**> libc.stdlib.malloc(len(names) * sizeof(const char *))
304+
for i in range(len(names)):
283305
self.c_names[i] = self.py_names[i]
284306

285307
def __dealloc__(self):
@@ -292,18 +314,18 @@ cdef class flint_mpoly_context(flint_elem):
292314
def __repr__(self):
293315
return f"{self.__class__.__name__}({self.nvars()}, '{repr(self.ordering())}', {self.names()})"
294316

295-
def name(self, long i):
317+
def name(self, i: int):
296318
if not 0 <= i < len(self.py_names):
297319
raise IndexError("variable name index out of range")
298320
return self.py_names[i].decode("ascii")
299321

300-
def names(self):
322+
def names(self) -> tuple[str]:
301323
return tuple(name.decode("ascii") for name in self.py_names)
302324

303325
def gens(self):
304326
return tuple(self.gen(i) for i in range(self.nvars()))
305327

306-
def variable_to_index(self, var: Union[int, str]):
328+
def variable_to_index(self, var: Union[int, str]) -> int:
307329
"""Convert a variable name string or possible index to its index in the context."""
308330
if isinstance(var, str):
309331
try:
@@ -320,48 +342,55 @@ cdef class flint_mpoly_context(flint_elem):
320342
return i
321343

322344
@staticmethod
323-
def create_variable_names(slong nvars, names: str):
345+
def create_variable_names(names: Iterable[str | tuple[str, int]]) -> tuple[str]:
324346
"""
325-
Create a tuple of variable names based on the comma separated ``names`` string.
326-
327-
If ``names`` contains a single value, and ``nvars`` > 1, then the variables are numbered, e.g.
347+
Create a tuple of variable names based off either ``Iterable[str]``,
348+
``tuple[str, int]``, or ``Iterable[tuple[str, int]]``.
328349

329-
>>> flint_mpoly_context.create_variable_names(3, "x")
350+
>>> flint_mpoly_context.create_variable_names([('x', 3), 'y'])
351+
('x0', 'x1', 'x2', 'y')
352+
>>> flint_mpoly_context.create_variable_names(('x', 3))
330353
('x0', 'x1', 'x2')
331-
332354
"""
333-
nametup = tuple(name.strip() for name in names.split(','))
334-
if len(nametup) != nvars:
335-
if len(nametup) == 1:
336-
nametup = tuple(nametup[0] + str(i) for i in range(nvars))
355+
res: list[str] = []
356+
357+
# Provide a convenience method to avoid having to pass a nested tuple
358+
if len(names) == 2 and isinstance(names[0], str) and isinstance(names[1], int):
359+
names = (names,)
360+
361+
for name in names:
362+
if isinstance(name, str):
363+
res.append(name)
337364
else:
338-
raise ValueError("number of variables does not equal number of names")
339-
return nametup
365+
base, num = name
366+
if num < 0:
367+
raise ValueError("cannot create a negative number of variables")
368+
res.extend(base + str(i) for i in range(num))
369+
370+
return tuple(res)
340371

341372
@classmethod
342-
def create_context_key(cls, slong nvars=1, ordering=Ordering.lex, names: Optional[str] = "x", nametup: Optional[tuple] = None):
373+
def create_context_key(
374+
cls,
375+
names: Iterable[str | tuple[str, int]],
376+
ordering: Ordering | str = Ordering.lex
377+
):
343378
"""
344-
Create a key for the context cache via the number of variables, the ordering, and
345-
either a variable name string, or a tuple of variable names.
379+
Create a key for the context cache via the variable names and the ordering.
346380
"""
347381
# A type hint of ``ordering: Ordering`` results in the error "TypeError: an integer is required" if a Ordering
348382
# object is not provided. This is pretty obtuse so we check its type ourselves
349-
if not isinstance(ordering, Ordering):
350-
raise TypeError(f"'ordering' ('{ordering}') is not an instance of flint.Ordering")
383+
# if not isinstance(ordering, Ordering):
384+
# raise TypeError(f"'ordering' ('{ordering}') is not an instance of flint.Ordering")
351385

352-
if nametup is not None:
353-
key = nvars, ordering, nametup
354-
elif nametup is None and names is not None:
355-
key = nvars, ordering, cls.create_variable_names(nvars, names)
356-
else:
357-
raise ValueError("must provide either 'names' or 'nametup'")
358-
return key
386+
return cls.create_variable_names(names), Ordering(ordering) if not isinstance(ordering, Ordering) else ordering
359387

360388
@classmethod
361389
def get_context(cls, *args, **kwargs):
362390
"""
363-
Retrieve a context via the number of variables, ``nvars``, the ordering, ``ordering``, and either a variable
364-
name string, ``names``, or a tuple of variable names, ``nametup``.
391+
Retrieve or create a context via generator names, ``names`` and the ordering, ``ordering``.
392+
393+
See ``create_variable_names`` for naming schemes.
365394
"""
366395
key = cls.create_context_key(*args, **kwargs)
367396

@@ -373,10 +402,8 @@ cdef class flint_mpoly_context(flint_elem):
373402
@classmethod
374403
def from_context(cls, ctx: flint_mpoly_context):
375404
return cls.get_context(
376-
nvars=ctx.nvars(),
377405
ordering=ctx.ordering(),
378-
names=None,
379-
nametup=ctx.names()
406+
names=ctx.names(),
380407
)
381408

382409
def _any_as_scalar(self, other):
@@ -410,35 +437,29 @@ cdef class flint_mpoly_context(flint_elem):
410437
return self.from_dict({tuple(exp_vec): coeff})
411438

412439
cdef class flint_mod_mpoly_context(flint_mpoly_context):
413-
def __init__(self, nvars, names, prime_modulus):
414-
super().__init__(nvars, names)
440+
def __init__(self, names, prime_modulus):
441+
super().__init__(names)
415442
self.__prime_modulus = <bint>prime_modulus
416443

417444
@classmethod
418445
def create_context_key(
419446
cls,
420-
slong nvars=1,
421-
ordering=Ordering.lex,
422-
modulus = None,
423-
names: Optional[str] = "x",
424-
nametup: Optional[tuple] = None,
447+
names: Iterable[str | tuple[str, int]],
448+
modulus,
449+
ordering: Ordering | str = Ordering.lex
425450
):
426451
"""
427-
Create a key for the context cache via the number of variables, the ordering, the modulus, and either a
428-
variable name string, or a tuple of variable names.
452+
Create a key for the context cache via the variable names, modulus, and the ordering.
429453
"""
430-
# A type hint of ``ordering: Ordering`` results in the error "TypeError: an integer is required" if a Ordering
431-
# object is not provided. This is pretty obtuse so we check its type ourselves
432-
if not isinstance(ordering, Ordering):
433-
raise TypeError(f"'ordering' ('{ordering}') is not an instance of flint.Ordering")
454+
return *super().create_context_key(names, ordering), modulus
434455

435-
if nametup is not None:
436-
key = nvars, ordering, nametup, modulus
437-
elif nametup is None and names is not None:
438-
key = nvars, ordering, cls.create_variable_names(nvars, names), modulus
439-
else:
440-
raise ValueError("must provide either 'names' or 'nametup'")
441-
return key
456+
@classmethod
457+
def from_context(cls, ctx: flint_mod_mpoly_context):
458+
return cls.get_context(
459+
names=ctx.names(),
460+
modulus=ctx.modulus(),
461+
ordering=ctx.ordering(),
462+
)
442463

443464
def is_prime(self):
444465
"""
@@ -898,26 +919,3 @@ cdef class flint_mat(flint_elem):
898919

899920
# supports mpmath conversions
900921
tolist = table
901-
902-
903-
cdef ordering_t ordering_py_to_c(ordering): # Cython does not like an "Ordering" type hint here
904-
if not isinstance(ordering, Ordering):
905-
raise TypeError(f"'ordering' ('{ordering}') is not an instance of flint.Ordering")
906-
907-
if ordering == Ordering.lex:
908-
return ordering_t.ORD_LEX
909-
elif ordering == Ordering.deglex:
910-
return ordering_t.ORD_DEGLEX
911-
elif ordering == Ordering.degrevlex:
912-
return ordering_t.ORD_DEGREVLEX
913-
914-
915-
cdef ordering_c_to_py(ordering_t ordering):
916-
if ordering == ordering_t.ORD_LEX:
917-
return Ordering.lex
918-
elif ordering == ordering_t.ORD_DEGLEX:
919-
return Ordering.deglex
920-
elif ordering == ordering_t.ORD_DEGREVLEX:
921-
return Ordering.degrevlex
922-
else:
923-
raise ValueError("unimplemented term order %d" % ordering)

src/flint/test/test_all.py

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2859,14 +2859,10 @@ def test_mpolys():
28592859
# division is exact or not.
28602860
composite_characteristic = characteristic != 0 and not characteristic.is_prime()
28612861

2862-
ctx = get_context(nvars=2)
2862+
ctx = get_context((("x", 2),))
28632863

2864-
assert raises(lambda: get_context(nvars=2, ordering="bad"), TypeError)
2865-
assert raises(lambda: get_context(nvars=-1), ValueError)
2866-
if ctx.__class__ is flint.fmpz_mod_mpoly_ctx or ctx.__class__ is flint.nmod_mpoly_ctx:
2867-
assert raises(lambda: ctx.__class__(-1, flint.Ordering.lex, [], 4), ValueError)
2868-
else:
2869-
assert raises(lambda: ctx.__class__(-1, flint.Ordering.lex, []), ValueError)
2864+
assert raises(lambda: get_context((("x", 2),), ordering="bad"), ValueError)
2865+
assert raises(lambda: get_context((("x", -1),)), ValueError)
28702866
assert raises(lambda: ctx.constant("bad"), TypeError)
28712867
assert raises(lambda: ctx.from_dict("bad"), ValueError)
28722868
assert raises(lambda: ctx.from_dict({(0, 0): "bad"}), TypeError)
@@ -2875,7 +2871,7 @@ def test_mpolys():
28752871
assert raises(lambda: ctx.gen(-1), IndexError)
28762872
assert raises(lambda: ctx.gen(10), IndexError)
28772873

2878-
assert raises(lambda: P(val=get_context(nvars=1).constant(0), ctx=ctx), IncompatibleContextError)
2874+
assert raises(lambda: P(val=get_context(("x",)).constant(0), ctx=ctx), IncompatibleContextError)
28792875
assert raises(lambda: P(val={}, ctx=None), ValueError)
28802876
assert raises(lambda: P(val={"bad": 1}, ctx=None), ValueError)
28812877
assert raises(lambda: P(val="1", ctx=None), ValueError)
@@ -2894,10 +2890,10 @@ def quick_poly():
28942890
assert ctx.nvars() == 2
28952891
assert ctx.ordering() == flint.Ordering.lex
28962892

2897-
ctx1 = get_context(4)
2893+
ctx1 = get_context((("x", 4),))
28982894
assert [ctx1.name(i) for i in range(4)] == ['x0', 'x1', 'x2', 'x3']
28992895
for order in list(flint.Ordering):
2900-
ctx1 = get_context(4, order)
2896+
ctx1 = get_context((("x", 4),), ordering=order)
29012897
assert ctx1.ordering() == order
29022898

29032899
assert ctx.constant(1) == mpoly({(0, 0): 1}) == P(1, ctx=ctx)
@@ -2948,7 +2944,7 @@ def quick_poly():
29482944
assert P({(0, 1): 3}, ctx=ctx) == ctx.from_dict({(0, 1): 3})
29492945

29502946
if P is flint.fmpq_mpoly:
2951-
ctx_z = flint.fmpz_mpoly_ctx.get_context(2)
2947+
ctx_z = flint.fmpz_mpoly_ctx.get_context((("x", 2),))
29522948
assert quick_poly() == P(ctx_z.from_dict({(0, 0): 1, (0, 1): 2, (1, 0): 3, (2, 2): 4}))
29532949
assert P(ctx_z.from_dict({(0, 0): 1}), ctx=ctx) == P({(0, 0): 1}, ctx=ctx)
29542950

@@ -2997,8 +2993,8 @@ def quick_poly():
29972993

29982994
assert raises(lambda: p.__setitem__((2, 1), None), TypeError)
29992995

3000-
assert P(ctx=ctx).repr() == f"{ctx.__class__.__name__}(2, '<Ordering.lex: 0>', ('x0', 'x1')).from_dict({{}})"
3001-
assert P(1, ctx=ctx).repr() == f"{ctx.__class__.__name__}(2, '<Ordering.lex: 0>', ('x0', 'x1')).from_dict({{(0, 0): 1}})"
2996+
assert P(ctx=ctx).repr() == f"{ctx.__class__.__name__}(2, '<Ordering.lex: 'lex'>', ('x0', 'x1')).from_dict({{}})"
2997+
assert P(1, ctx=ctx).repr() == f"{ctx.__class__.__name__}(2, '<Ordering.lex: 'lex'>', ('x0', 'x1')).from_dict({{(0, 0): 1}})"
30022998
assert str(quick_poly()) == repr(quick_poly()) == '4*x0^2*x1^2 + 3*x0 + 2*x1 + 1'
30032999

30043000
assert p.monomial(0) == (2, 2)
@@ -3039,7 +3035,7 @@ def quick_poly():
30393035
assert raises(lambda: p.subs({"a": 1}), ValueError)
30403036
assert raises(lambda: p.subs({"x0": 0, "x1": 1, "x2": 2}), ValueError)
30413037

3042-
no_gens_ctx = get_context(0)
3038+
no_gens_ctx = get_context(tuple())
30433039
no_gens_p = P("2", no_gens_ctx)
30443040
assert no_gens_p.compose(ctx=ctx1).context() is ctx1
30453041
assert raises(lambda: no_gens_p.compose(), ValueError)
@@ -3318,8 +3314,8 @@ def test_fmpz_mpoly_vec():
33183314
for context, mpoly_vec in _all_mpoly_vecs():
33193315
has_groebner_functions = mpoly_vec is flint.fmpz_mpoly_vec
33203316

3321-
ctx = context.get_context(nvars=2)
3322-
ctx1 = context.get_context(nvars=4)
3317+
ctx = context.get_context((("x", 2),))
3318+
ctx1 = context.get_context((("x", 4),))
33233319
x, y = ctx.gens()
33243320

33253321
vec = mpoly_vec(3, ctx)
@@ -3348,7 +3344,7 @@ def test_fmpz_mpoly_vec():
33483344
assert raises(lambda: vec.__setitem__(0, ctx1.from_dict({})), IncompatibleContextError)
33493345

33503346
if has_groebner_functions:
3351-
ctx = context.get_context(3, flint.Ordering.lex, nametup=('x', 'y', 'z'))
3347+
ctx = context.get_context(("x", "y", "z"))
33523348

33533349
# Examples here cannibalised from
33543350
# https://en.wikipedia.org/wiki/Gr%C3%B6bner_basis#Example_and_counterexample
@@ -3412,7 +3408,7 @@ def _all_polys_mpolys():
34123408
yield P, S, [x, y], is_field, characteristic
34133409

34143410
for P, get_context, S, is_field, characteristic in _all_mpolys():
3415-
ctx = get_context(2, flint.Ordering.lex, nametup=("x", "y"))
3411+
ctx = get_context(("x", "y"))
34163412
x, y = ctx.gens()
34173413
assert isinstance(x, (
34183414
flint.fmpz_mpoly,

src/flint/types/fmpq_mpoly.pyx

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -98,9 +98,9 @@ cdef class fmpq_mpoly_ctx(flint_mpoly_context):
9898

9999
_ctx_cache = _fmpq_mpoly_ctx_cache
100100

101-
def __init__(self, slong nvars, ordering, names):
102-
fmpq_mpoly_ctx_init(self.val, nvars, ordering_py_to_c(ordering))
103-
super().__init__(nvars, names)
101+
def __init__(self, names, ordering):
102+
super().__init__(names)
103+
fmpq_mpoly_ctx_init(self.val, len(names), ordering_py_to_c(ordering))
104104

105105
def _any_as_scalar(self, other):
106106
if isinstance(other, int):

0 commit comments

Comments
 (0)