@@ -5,49 +5,83 @@ from flint.types.fmpz cimport fmpz, any_as_fmpz
55from  flint.types.fmpz_poly cimport any_as_fmpz_poly
66from  flint.types.fmpz_poly cimport fmpz_poly
77from  flint.types.nmod cimport any_as_nmod_ctx
8- from  flint.types.nmod cimport nmod
8+ from  flint.types.nmod cimport nmod, nmod_ctx 
99
1010from  flint.flintlib.nmod_vec cimport * 
1111from  flint.flintlib.nmod_poly cimport * 
1212from  flint.flintlib.nmod_poly_factor cimport * 
1313from  flint.flintlib.fmpz_poly cimport fmpz_poly_get_nmod_poly
14- from  flint.flintlib.ulong_extras cimport n_gcdinv
14+ from  flint.flintlib.ulong_extras cimport n_gcdinv, n_is_prime 
1515
1616from  flint.utils.flint_exceptions import  DomainError
1717
1818
19- cdef any_as_nmod_poly(obj, nmod_t mod):
20-     cdef nmod_poly r
21-     cdef mp_limb_t v
22-     #  XXX: should check that modulus is the same here, and not all over the place
23-     if  typecheck(obj, nmod_poly):
19+ _nmod_poly_ctx_cache =  {}
20+ 
21+ 
22+ cdef nmod_ctx any_as_nmod_poly_ctx(obj):
23+     """ Convert an int to an nmod_ctx.""" 
24+     if  typecheck(obj, nmod_poly_ctx):
2425        return  obj
25-     if  any_as_nmod(& v, obj, mod):
26-         r =  nmod_poly.__new__ (nmod_poly)
27-         nmod_poly_init(r.val, mod.n)
28-         nmod_poly_set_coeff_ui(r.val, 0 , v)
29-         return  r
30-     x =  any_as_fmpz_poly(obj)
31-     if  x is  not  NotImplemented :
32-         r =  nmod_poly.__new__ (nmod_poly)
33-         nmod_poly_init(r.val, mod.n)   #  XXX: create flint _nmod_poly_set_modulus for this?
34-         fmpz_poly_get_nmod_poly(r.val, (< fmpz_poly> x).val)
35-         return  r
26+     if  typecheck(obj, int ):
27+         ctx =  _nmod_poly_ctx_cache.get(obj)
28+         if  ctx is  None :
29+             ctx =  nmod_poly_ctx(obj)
30+             _nmod_poly_ctx_cache[obj] =  ctx
31+         return  ctx
3632    return  NotImplemented 
3733
38- cdef nmod_poly_set_list(nmod_poly_t poly, list  val):
39-     cdef long  i, n
40-     cdef nmod_t mod
41-     cdef mp_limb_t v
42-     nmod_init(& mod, nmod_poly_modulus(poly)) #  XXX
43-     n =  PyList_GET_SIZE(val)
44-     nmod_poly_fit_length(poly, n)
45-     for  i from  0  <=  i <  n:
46-         c =  val[i]
47-         if  any_as_nmod(& v, val[i], mod):
48-             nmod_poly_set_coeff_ui(poly, i, v)
49-         else :
50-             raise  TypeError (" unsupported coefficient in list"  )
34+ 
35+ cdef class  nmod_poly_ctx:
36+     """ 
37+     Context object for creating :class:`~.nmod_poly` initalised  
38+     with modulus :math:`N`. 
39+ 
40+         >>> nmod_ctx(17) 
41+         nmod_ctx(17) 
42+ 
43+     """  
44+     def  __init__  (self , mod ):
45+         cdef mp_limb_t m
46+         m =  mod
47+         nmod_init(& self .mod, m)
48+         self .ctx =  nmod_ctx(mod)
49+         self ._is_prime =  n_is_prime(m)
50+ 
51+     cdef int  any_as_nmod(self , mp_limb_t *  val, obj) except  - 1 :
52+         return  self .ctx.any_as_nmod(val, obj)
53+ 
54+     cdef any_as_nmod_poly(self , obj):
55+         cdef nmod_poly r
56+         cdef mp_limb_t v
57+         #  XXX: should check that modulus is the same here, and not all over the place
58+         if  typecheck(obj, nmod_poly):
59+             return  obj
60+         if  self .ctx.any_as_nmod(& v, obj):
61+             r =  nmod_poly.__new__ (nmod_poly)
62+             nmod_poly_init(r.val, self .mod.n)
63+             nmod_poly_set_coeff_ui(r.val, 0 , v)
64+             return  r
65+         x =  any_as_fmpz_poly(obj)
66+         if  x is  not  NotImplemented :
67+             r =  nmod_poly.__new__ (nmod_poly)
68+             nmod_poly_init(r.val, self .mod.n)   #  XXX: create flint _nmod_poly_set_modulus for this?
69+             fmpz_poly_get_nmod_poly(r.val, (< fmpz_poly> x).val)
70+             return  r
71+         return  NotImplemented 
72+ 
73+     cdef nmod_poly_set_list(self , nmod_poly_t poly, list  val):
74+         cdef long  i, n
75+         cdef mp_limb_t v
76+         n =  PyList_GET_SIZE(val)
77+         nmod_poly_fit_length(poly, n)
78+         for  i from  0  <=  i <  n:
79+             c =  val[i]
80+             if  self .any_as_nmod(& v, val[i]):
81+                 nmod_poly_set_coeff_ui(poly, i, v)
82+             else :
83+                 raise  TypeError (" unsupported coefficient in list"  )
84+ 
5185
5286cdef class  nmod_poly(flint_poly):
5387    """ 
@@ -79,24 +113,32 @@ cdef class nmod_poly(flint_poly):
79113    def  __dealloc__ (self ):
80114        nmod_poly_clear(self .val)
81115
82-     def  __init__  (self , val = None , ulong  mod = 0 ):
116+     def  __init__  (self , val = None , mod = 0 ):
83117        cdef ulong m2
84118        cdef mp_limb_t v
119+         cdef nmod_poly_ctx ctx
120+ 
85121        if  typecheck(val, nmod_poly):
86122            m2 =  nmod_poly_modulus((< nmod_poly> val).val)
87123            if  m2 !=  mod:
88124                raise  ValueError (" different moduli!"  )
89125            nmod_poly_init(self .val, m2)
90126            nmod_poly_set(self .val, (< nmod_poly> val).val)
127+             self .ctx =  (< nmod_poly> val).ctx
91128        else :
92129            if  mod ==  0 :
93130                raise  ValueError (" a nonzero modulus is required"  )
94-             nmod_poly_init(self .val, mod)
131+             ctx =  any_as_nmod_poly_ctx(mod)
132+             if  ctx is  NotImplemented :
133+                 raise  TypeError (" cannot create nmod_poly_ctx from input of type %s "  , type (mod))
134+ 
135+             self .ctx =  ctx
136+             nmod_poly_init(self .val, ctx.mod.n)
95137            if  typecheck(val, fmpz_poly):
96138                fmpz_poly_get_nmod_poly(self .val, (< fmpz_poly> val).val)
97139            elif  typecheck(val, list ):
98-                 nmod_poly_set_list(self .val, val)
99-             elif  any_as_nmod(& v, val,  self .val.mod ):
140+                 ctx. nmod_poly_set_list(self .val, val)
141+             elif  ctx. any_as_nmod(& v, val):
100142                nmod_poly_fit_length(self .val, 1 )
101143                nmod_poly_set_coeff_ui(self .val, 0 , v)
102144            else :
@@ -178,7 +220,7 @@ cdef class nmod_poly(flint_poly):
178220        cdef mp_limb_t v
179221        if  i <  0 :
180222            raise  ValueError (" cannot assign to index < 0 of polynomial"  )
181-         if  any_as_nmod(& v, x,  self .val.mod ):
223+         if  self .ctx. any_as_nmod(& v, x):
182224            nmod_poly_set_coeff_ui(self .val, i, v)
183225        else :
184226            raise  TypeError (" cannot set element of type %s "   %  type (x))
@@ -291,7 +333,7 @@ cdef class nmod_poly(flint_poly):
291333            9*x^4 + 12*x^3 + 10*x^2 + 4*x + 1 
292334        """  
293335        cdef nmod_poly res
294-         other =  any_as_nmod_poly(other, ( < nmod_poly > self ).val.mod )
336+         other =  self .ctx.any_as_nmod_poly(other )
295337        if  other is  NotImplemented :
296338            raise  TypeError (" cannot convert input to nmod_poly"  )
297339        res =  nmod_poly.__new__ (nmod_poly)
@@ -316,11 +358,11 @@ cdef class nmod_poly(flint_poly):
316358            147*x^3 + 159*x^2 + 4*x + 7 
317359        """  
318360        cdef nmod_poly res
319-         g =  any_as_nmod_poly(other,  self .val.mod )
361+         g =  self .ctx.any_as_nmod_poly(other )
320362        if  g is  NotImplemented :
321363            raise  TypeError (f" cannot convert {other = } to nmod_poly"  )
322364
323-         h =  any_as_nmod_poly(modulus,  self .val.mod )
365+         h =  self . any_as_nmod_poly(modulus)
324366        if  h is  NotImplemented :
325367            raise  TypeError (f" cannot convert {modulus = } to nmod_poly"  )
326368
@@ -334,11 +376,11 @@ cdef class nmod_poly(flint_poly):
334376
335377    def  __call__  (self , other ):
336378        cdef mp_limb_t c
337-         if  any_as_nmod(& c, other,  self .val.mod ):
379+         if  self .ctx. any_as_nmod(& c, other):
338380            v =  nmod(0 , self .modulus())
339381            (< nmod> v).val =  nmod_poly_evaluate_nmod(self .val, c)
340382            return  v
341-         t =  any_as_nmod_poly(other,  self .val.mod )
383+         t =  self .ctx.any_as_nmod_poly(other )
342384        if  t is  not  NotImplemented :
343385            r =  nmod_poly.__new__ (nmod_poly)
344386            nmod_poly_init_preinv((< nmod_poly> r).val, self .val.mod.n, self .val.mod.ninv)
@@ -369,7 +411,7 @@ cdef class nmod_poly(flint_poly):
369411
370412    def  _add_ (s , t ):
371413        cdef nmod_poly r
372-         t =  any_as_nmod_poly(t, ( < nmod_poly > s).val.mod )
414+         t =  s.ctx.any_as_nmod_poly(t )
373415        if  t is  NotImplemented :
374416            return  t
375417        if  (< nmod_poly> s).val.mod.n !=  (< nmod_poly> t).val.mod.n:
@@ -395,20 +437,20 @@ cdef class nmod_poly(flint_poly):
395437        return  r
396438
397439    def  __sub__  (s , t ):
398-         t =  any_as_nmod_poly(t, ( < nmod_poly > s).val.mod )
440+         t =  s.ctx.any_as_nmod_poly(t )
399441        if  t is  NotImplemented :
400442            return  t
401443        return  s._sub_(t)
402444
403445    def  __rsub__  (s , t ):
404-         t =  any_as_nmod_poly(t, ( < nmod_poly > s).val.mod )
446+         t =  s. any_as_nmod_poly(t)
405447        if  t is  NotImplemented :
406448            return  t
407449        return  t._sub_(s)
408450
409451    def  _mul_ (s , t ):
410452        cdef nmod_poly r
411-         t =  any_as_nmod_poly(t, ( < nmod_poly > s).val.mod )
453+         t =  s. any_as_nmod_poly(t)
412454        if  t is  NotImplemented :
413455            return  t
414456        if  (< nmod_poly> s).val.mod.n !=  (< nmod_poly> t).val.mod.n:
@@ -425,7 +467,7 @@ cdef class nmod_poly(flint_poly):
425467        return  s._mul_(t)
426468
427469    def  __truediv__  (s , t ):
428-         t =  any_as_nmod_poly(t, ( < nmod_poly > s).val.mod )
470+         t =  s. any_as_nmod_poly(t)
429471        if  t is  NotImplemented :
430472            return  t
431473        res, r =  s._divmod_(t)
@@ -434,7 +476,7 @@ cdef class nmod_poly(flint_poly):
434476        return  res
435477
436478    def  __rtruediv__  (s , t ):
437-         t =  any_as_nmod_poly(t, ( < nmod_poly > s).val.mod )
479+         t =  s. any_as_nmod_poly(t)
438480        if  t is  NotImplemented :
439481            return  t
440482        res, r =  t._divmod_(s)
@@ -454,13 +496,13 @@ cdef class nmod_poly(flint_poly):
454496        return  r
455497
456498    def  __floordiv__  (s , t ):
457-         t =  any_as_nmod_poly(t, ( < nmod_poly > s).val.mod )
499+         t =  s. any_as_nmod_poly(t)
458500        if  t is  NotImplemented :
459501            return  t
460502        return  s._floordiv_(t)
461503
462504    def  __rfloordiv__  (s , t ):
463-         t =  any_as_nmod_poly(t, ( < nmod_poly > s).val.mod )
505+         t =  s. any_as_nmod_poly(t)
464506        if  t is  NotImplemented :
465507            return  t
466508        return  t._floordiv_(s)
@@ -479,13 +521,13 @@ cdef class nmod_poly(flint_poly):
479521        return  P, Q
480522
481523    def  __divmod__  (s , t ):
482-         t =  any_as_nmod_poly(t, ( < nmod_poly > s).val.mod )
524+         t =  s. any_as_nmod_poly(t)
483525        if  t is  NotImplemented :
484526            return  t
485527        return  s._divmod_(t)
486528
487529    def  __rdivmod__  (s , t ):
488-         t =  any_as_nmod_poly(t, ( < nmod_poly > s).val.mod )
530+         t =  s. any_as_nmod_poly(t)
489531        if  t is  NotImplemented :
490532            return  t
491533        return  t._divmod_(s)
@@ -534,7 +576,7 @@ cdef class nmod_poly(flint_poly):
534576        if  e <  0 :
535577            raise  ValueError (" Exponent must be non-negative"  )
536578
537-         modulus =  any_as_nmod_poly(modulus, ( < nmod_poly > self ).val.mod )
579+         modulus =  self .ctx.any_as_nmod_poly(modulus )
538580        if  modulus is  NotImplemented :
539581            raise  TypeError (" cannot convert input to nmod_poly"  )
540582
@@ -556,7 +598,7 @@ cdef class nmod_poly(flint_poly):
556598
557599        #  To optimise powering, we precompute the inverse of the reverse of the modulus
558600        if  mod_rev_inv is  not  None :
559-             mod_rev_inv =  any_as_nmod_poly(mod_rev_inv, ( < nmod_poly > self ).val.mod )
601+             mod_rev_inv =  self . any_as_nmod_poly(mod_rev_inv)
560602            if  mod_rev_inv is  NotImplemented :
561603                raise  TypeError (f" Cannot interpret {mod_rev_inv} as a polynomial"  )
562604        else :
@@ -585,7 +627,7 @@ cdef class nmod_poly(flint_poly):
585627
586628        """  
587629        cdef nmod_poly res
588-         other =  any_as_nmod_poly(other, ( < nmod_poly > self ).val.mod )
630+         other =  self . any_as_nmod_poly(other)
589631        if  other is  NotImplemented :
590632            raise  TypeError (" cannot convert input to nmod_poly"  )
591633        if  self .val.mod.n !=  (< nmod_poly> other).val.mod.n:
@@ -597,7 +639,7 @@ cdef class nmod_poly(flint_poly):
597639
598640    def  xgcd (self , other ):
599641        cdef nmod_poly res1, res2, res3
600-         other =  any_as_nmod_poly(other, ( < nmod_poly > self ).val.mod )
642+         other =  self . any_as_nmod_poly(other)
601643        if  other is  NotImplemented :
602644            raise  TypeError (" cannot convert input to fmpq_poly"  )
603645        res1 =  nmod_poly.__new__ (nmod_poly)
0 commit comments