@@ -599,11 +599,17 @@ def get_connection(self):
599
599
return ConnectionContextManager (self .pool )
600
600
601
601
602
- def register_model (self , ModelSubclass ):
602
+ def register_model (self , ModelSubclass , typname = None ):
603
603
"""Register an ORM model.
604
604
605
605
:param ModelSubclass: the :py:class:`~postgres.orm.Model` subclass to
606
606
register with this :py:class:`~postgres.Postgres` instance
607
+
608
+ :param typname: a string indicating the Postgres type to register this
609
+ model for (``typname``, without an "e," is the name of the relevant
610
+ column in the underlying ``pg_type`` table). If :py:class:`None`,
611
+ we'll look for :py:attr:`ModelSubclass.typname`.
612
+
607
613
:raises: :py:exc:`~postgres.NotAModel`,
608
614
:py:exc:`~postgres.NoTypeSpecified`,
609
615
:py:exc:`~postgres.NoSuchType`,
@@ -618,28 +624,30 @@ def register_model(self, ModelSubclass):
618
624
if not issubclass (ModelSubclass , Model ):
619
625
raise NotAModel (ModelSubclass )
620
626
621
- if getattr (ModelSubclass , 'typname' , None ) is None :
622
- raise NoTypeSpecified (ModelSubclass )
627
+ if typname is None :
628
+ typname = getattr (ModelSubclass , 'typname' , None )
629
+ if typname is None :
630
+ raise NoTypeSpecified (ModelSubclass )
623
631
624
632
n = self .one ( "SELECT count(*) FROM pg_type WHERE typname=%s"
625
- , (ModelSubclass . typname ,)
633
+ , (typname ,)
626
634
)
627
635
if n < 1 :
628
636
# Could be more than one since we don't constrain by typnamespace.
629
637
# XXX What happens then?
630
- raise NoSuchType (ModelSubclass . typname )
638
+ raise NoSuchType (typname )
631
639
632
- if ModelSubclass . typname in self .model_registry :
633
- existing_model = self .model_registry [ModelSubclass . typname ]
634
- raise AlreadyRegistered (existing_model , ModelSubclass . typname )
640
+ if typname in self .model_registry :
641
+ existing_model = self .model_registry [typname ]
642
+ raise AlreadyRegistered (existing_model , typname )
635
643
636
- self .model_registry [ModelSubclass . typname ] = ModelSubclass
644
+ self .model_registry [typname ] = ModelSubclass
637
645
ModelSubclass .db = self
638
646
639
- # register a composite (but don't use RealDictCursor, not sure why)
647
+ # register a composite
640
648
with self .get_connection () as conn :
641
649
cursor = conn .cursor ()
642
- name = ModelSubclass . typname
650
+ name = typname
643
651
if sys .version_info [0 ] < 3 :
644
652
name = name .encode ('UTF-8' )
645
653
register_composite ( name
@@ -656,30 +664,40 @@ def unregister_model(self, ModelSubclass):
656
664
unregister
657
665
:raises: :py:exc:`~postgres.NotRegistered`
658
666
667
+ If :py:class:`ModelSubclass` is registered for multiple types, it is
668
+ unregistered for all of them.
669
+
659
670
"""
660
- key = self .check_registration (ModelSubclass )
661
- del self .model_registry [key ]
671
+ keys = self .check_registration (ModelSubclass )
672
+ if not isinstance (keys , list ):
673
+ # Wrap single string in a list. Flip-side of XXX just below.
674
+ keys = [keys ]
675
+ for key in keys :
676
+ del self .model_registry [key ]
662
677
663
678
664
679
def check_registration (self , ModelSubclass ):
665
680
"""Check whether an ORM model is registered.
666
681
667
682
:param ModelSubclass: the :py:class:`~postgres.orm.Model` subclass to
668
683
check for
684
+
669
685
:returns: the :py:attr:`typname` (a string) for which this model is
670
- registered
686
+ registered, or a list of strings if it's registered for multiple
687
+ types
688
+
671
689
:rettype: string
672
690
:raises: :py:exc:`~postgres.NotRegistered`
673
691
674
692
"""
675
- key = None
676
- for k , v in self .model_registry .items ():
677
- if v is ModelSubclass :
678
- key = k
679
- break
680
- if key is None :
693
+ keys = [k for k ,v in self .model_registry .items () if v is ModelSubclass ]
694
+ if not keys :
681
695
raise NotRegistered (ModelSubclass )
682
- return key
696
+ if len (keys ) == 1 :
697
+ # Dereference a single-item list, for backwards-compatibility.
698
+ # XXX If/when we go to 3.0, lose this cruft (always return list).
699
+ keys = keys [0 ]
700
+ return keys
683
701
684
702
685
703
# Class Factories
0 commit comments