168
168
from collections import namedtuple
169
169
170
170
import psycopg2
171
+ from inspect import isclass
171
172
from postgres .context_managers import ConnectionContextManager
172
173
from postgres .context_managers import CursorContextManager
173
174
from postgres .cursors import SimpleTupleCursor , SimpleNamedTupleCursor
@@ -214,8 +215,8 @@ def __str__(self):
214
215
class NotAModel (Exception ):
215
216
def __str__ (self ):
216
217
return "Only subclasses of postgres.orm.Model can be registered as " \
217
- "orm models. {} (registered for {}) doesn't fit the bill." \
218
- .format (self .args [0 ]. __name__ , self . args [ 1 ] )
218
+ "orm models. {} doesn't fit the bill." \
219
+ .format (self .args [0 ])
219
220
220
221
class NoTypeSpecified (Exception ):
221
222
def __str__ (self ):
@@ -621,8 +622,7 @@ def register_model(self, ModelSubclass, typname=None):
621
622
subclassing :py:class:`~postgres.orm.Model`.
622
623
623
624
"""
624
- if not issubclass (ModelSubclass , Model ):
625
- raise NotAModel (ModelSubclass )
625
+ self ._validate_model_subclass (ModelSubclass )
626
626
627
627
if typname is None :
628
628
typname = getattr (ModelSubclass , 'typname' , None )
@@ -676,11 +676,14 @@ def unregister_model(self, ModelSubclass):
676
676
del self .model_registry [key ]
677
677
678
678
679
- def check_registration (self , ModelSubclass ):
679
+ def check_registration (self , ModelSubclass , include_subsubclasses = False ):
680
680
"""Check whether an ORM model is registered.
681
681
682
682
:param ModelSubclass: the :py:class:`~postgres.orm.Model` subclass to
683
683
check for
684
+ :param bool include_subsubclasses: whether to also check for subclasses
685
+ of :py:class:`ModelSubclass` or just :py:class:`ModelSubclass`
686
+ itself
684
687
685
688
:returns: the :py:attr:`typname` (a string) for which this model is
686
689
registered, or a list of strings if it's registered for multiple
@@ -690,7 +693,13 @@ def check_registration(self, ModelSubclass):
690
693
:raises: :py:exc:`~postgres.NotRegistered`
691
694
692
695
"""
693
- keys = [k for k ,v in self .model_registry .items () if v is ModelSubclass ]
696
+ self ._validate_model_subclass (ModelSubclass )
697
+
698
+ if include_subsubclasses :
699
+ filt = lambda v : v is ModelSubclass or issubclass (ModelSubclass , v )
700
+ else :
701
+ filt = lambda v : v is ModelSubclass
702
+ keys = [k for k ,v in self .model_registry .items () if filt (v )]
694
703
if not keys :
695
704
raise NotRegistered (ModelSubclass )
696
705
if len (keys ) == 1 :
@@ -700,6 +709,11 @@ def check_registration(self, ModelSubclass):
700
709
return keys
701
710
702
711
712
+ def _validate_model_subclass (self , ModelSubclass , ):
713
+ if not isclass (ModelSubclass ) or not issubclass (ModelSubclass , Model ):
714
+ raise NotAModel (ModelSubclass )
715
+
716
+
703
717
# Class Factories
704
718
# ===============
705
719
0 commit comments