Skip to content

Commit c073d92

Browse files
committed
Disable __init__ for mpoly contexts
1 parent c2b1089 commit c073d92

File tree

6 files changed

+57
-17
lines changed

6 files changed

+57
-17
lines changed

src/flint/flint_base/flint_base.pyx

Lines changed: 34 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -298,11 +298,31 @@ cdef class flint_mpoly_context(flint_elem):
298298

299299
_ctx_cache = None
300300

301-
def __init__(self, names: Iterable[str]):
301+
def __init__(self, *_, **_2):
302+
raise RuntimeError(
303+
f"{self.__class__.__name__} should not be constructed directly. "
304+
f"Use '{self.__class__.__name__}.get' instead."
305+
)
306+
307+
@classmethod
308+
def _new_(_, flint_mpoly_context self, names: Iterable[str]):
309+
"""
310+
Constructor for all mpoly context types. This method is not intended for
311+
user-face use. See ``get`` instead.
312+
313+
Construction via ``__init__`` is disabled to prevent the accidental creation of
314+
new mpoly contexts. By ensuring each context is unique they can be compared via
315+
pointer comparisons.
316+
317+
Each concrete subclass should maintain their own context cache in
318+
``_ctx_cache``, and the ``get`` method should insert newly created contexts into
319+
the cache.
320+
"""
302321
self.py_names = tuple(name.encode("ascii") if not isinstance(name, bytes) else name for name in names)
303322
self.c_names = <const char**> libc.stdlib.malloc(len(names) * sizeof(const char *))
304323
for i in range(len(names)):
305324
self.c_names[i] = self.py_names[i]
325+
return self
306326

307327
def __dealloc__(self):
308328
libc.stdlib.free(self.c_names)
@@ -342,15 +362,17 @@ cdef class flint_mpoly_context(flint_elem):
342362
return i
343363

344364
@staticmethod
345-
def create_variable_names(names: Iterable[str | tuple[str, int]]) -> tuple[str]:
365+
def create_variable_names(names: str | Iterable[str | tuple[str, int]]) -> tuple[str]:
346366
"""
347-
Create a tuple of variable names based off either ``Iterable[str]``,
367+
Create a tuple of variable names based off either ``str``, ``Iterable[str]``,
348368
``tuple[str, int]``, or ``Iterable[tuple[str, int]]``.
349369

350-
>>> flint_mpoly_context.create_variable_names([('x', 3), 'y'])
351-
('x0', 'x1', 'x2', 'y')
370+
>>> flint_mpoly_context.create_variable_names('x')
371+
('x',)
352372
>>> flint_mpoly_context.create_variable_names(('x', 3))
353373
('x0', 'x1', 'x2')
374+
>>> flint_mpoly_context.create_variable_names([('x', 3), 'y'])
375+
('x0', 'x1', 'x2', 'y')
354376
"""
355377
res: list[str] = []
356378

@@ -372,7 +394,7 @@ cdef class flint_mpoly_context(flint_elem):
372394
@classmethod
373395
def create_context_key(
374396
cls,
375-
names: Iterable[str | tuple[str, int]],
397+
names: str | Iterable[str | tuple[str, int]],
376398
ordering: Ordering | str = Ordering.lex
377399
):
378400
"""
@@ -391,7 +413,7 @@ cdef class flint_mpoly_context(flint_elem):
391413

392414
ctx = cls._ctx_cache.get(key)
393415
if ctx is None:
394-
ctx = cls._ctx_cache.setdefault(key, cls(*key))
416+
ctx = cls._ctx_cache.setdefault(key, cls._new_(*key))
395417
return ctx
396418

397419
@classmethod
@@ -432,10 +454,13 @@ cdef class flint_mpoly_context(flint_elem):
432454
return self.from_dict({tuple(exp_vec): coeff})
433455

434456
cdef class flint_mod_mpoly_context(flint_mpoly_context):
435-
def __init__(self, names, prime_modulus):
436-
super().__init__(names)
457+
@classmethod
458+
def _new_(_, flint_mod_mpoly_context self, names, prime_modulus):
459+
super()._new_(self, names)
437460
self.__prime_modulus = <bint>prime_modulus
438461

462+
return self
463+
439464
@classmethod
440465
def create_context_key(
441466
cls,

src/flint/test/test_all.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2861,6 +2861,7 @@ def test_mpolys():
28612861

28622862
ctx = get_context((("x", 2),))
28632863

2864+
assert raises(lambda : ctx.__class__("x", flint.Ordering.lex), RuntimeError)
28642865
assert raises(lambda: get_context((("x", 2),), ordering="bad"), ValueError)
28652866
assert raises(lambda: get_context((("x", -1),)), ValueError)
28662867
assert raises(lambda: ctx.constant("bad"), TypeError)

src/flint/types/fmpq_mpoly.pyx

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,10 +97,14 @@ cdef class fmpq_mpoly_ctx(flint_mpoly_context):
9797

9898
_ctx_cache = _fmpq_mpoly_ctx_cache
9999

100-
def __init__(self, names, ordering):
101-
super().__init__(names)
100+
@classmethod
101+
def _new_(cls, names, ordering):
102+
cdef fmpq_mpoly_ctx self = cls.__new__(cls)
103+
super()._new_(self, names)
102104
fmpq_mpoly_ctx_init(self.val, len(names), ordering_py_to_c(ordering))
103105

106+
return self
107+
104108
def _any_as_scalar(self, other):
105109
if isinstance(other, int):
106110
return any_as_fmpq(other)

src/flint/types/fmpz_mod_mpoly.pyx

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,9 @@ cdef class fmpz_mod_mpoly_ctx(flint_mod_mpoly_context):
8989

9090
_ctx_cache = _fmpz_mod_mpoly_ctx_cache
9191

92-
def __init__(self, names, ordering, modulus):
92+
@classmethod
93+
def _new_(cls, names, ordering, modulus):
94+
cdef fmpz_mod_mpoly_ctx self = cls.__new__(cls)
9395
cdef fmpz m
9496
if not typecheck(modulus, fmpz):
9597
m = any_as_fmpz(modulus)
@@ -98,8 +100,9 @@ cdef class fmpz_mod_mpoly_ctx(flint_mod_mpoly_context):
98100
else:
99101
m = modulus
100102

101-
super().__init__(names, m.is_prime())
103+
super()._new_(self, names, m.is_prime())
102104
fmpz_mod_mpoly_ctx_init(self.val, len(names), ordering_py_to_c(ordering), m.val)
105+
return self
103106

104107
def _any_as_scalar(self, other):
105108
if isinstance(other, int):

src/flint/types/fmpz_mpoly.pyx

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,9 +101,12 @@ cdef class fmpz_mpoly_ctx(flint_mpoly_context):
101101

102102
_ctx_cache = _fmpz_mpoly_ctx_cache
103103

104-
def __init__(self, names, ordering):
105-
super().__init__(names)
104+
@classmethod
105+
def _new_(cls, names, ordering):
106+
cdef fmpz_mpoly_ctx self = cls.__new__(cls)
107+
super()._new_(self, names)
106108
fmpz_mpoly_ctx_init(self.val, len(names), ordering_py_to_c(ordering))
109+
return self
107110

108111
def _any_as_scalar(self, other):
109112
if isinstance(other, int):

src/flint/types/nmod_mpoly.pyx

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,12 +94,16 @@ cdef class nmod_mpoly_ctx(flint_mod_mpoly_context):
9494

9595
_ctx_cache = _nmod_mpoly_ctx_cache
9696

97-
def __init__(self, names, ordering, modulus: int):
97+
@classmethod
98+
def _new_(cls, names, ordering, modulus: int):
99+
cdef nmod_mpoly_ctx self = cls.__new__(cls)
100+
98101
if modulus <= 0:
99102
raise ValueError("modulus must be positive")
100103

101-
super().__init__(names, <bint>n_is_prime(modulus))
104+
super()._new_(self, names, <bint>n_is_prime(modulus))
102105
nmod_mpoly_ctx_init(self.val, len(names), ordering_py_to_c(ordering), modulus)
106+
return self
103107

104108
def _any_as_scalar(self, other):
105109
if isinstance(other, int):

0 commit comments

Comments
 (0)