Skip to content

Commit b7b9ddd

Browse files
committed
Relax check_registration to allow subsubclasses
This is necessary to allow __new__ in a Model subclass to return subclasses of the subclass.
1 parent 8c4fb80 commit b7b9ddd

File tree

3 files changed

+43
-9
lines changed

3 files changed

+43
-9
lines changed

postgres/__init__.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,7 @@
168168
from collections import namedtuple
169169

170170
import psycopg2
171+
from inspect import isclass
171172
from postgres.context_managers import ConnectionContextManager
172173
from postgres.context_managers import CursorContextManager
173174
from postgres.cursors import SimpleTupleCursor, SimpleNamedTupleCursor
@@ -214,8 +215,8 @@ def __str__(self):
214215
class NotAModel(Exception):
215216
def __str__(self):
216217
return "Only subclasses of postgres.orm.Model can be registered as " \
217-
"orm models. {} (registered for {}) doesn't fit the bill." \
218-
.format(self.args[0].__name__, self.args[1])
218+
"orm models. {} doesn't fit the bill." \
219+
.format(self.args[0])
219220

220221
class NoTypeSpecified(Exception):
221222
def __str__(self):
@@ -621,8 +622,7 @@ def register_model(self, ModelSubclass, typname=None):
621622
subclassing :py:class:`~postgres.orm.Model`.
622623
623624
"""
624-
if not issubclass(ModelSubclass, Model):
625-
raise NotAModel(ModelSubclass)
625+
self._validate_model_subclass(ModelSubclass)
626626

627627
if typname is None:
628628
typname = getattr(ModelSubclass, 'typname', None)
@@ -676,11 +676,14 @@ def unregister_model(self, ModelSubclass):
676676
del self.model_registry[key]
677677

678678

679-
def check_registration(self, ModelSubclass):
679+
def check_registration(self, ModelSubclass, include_subsubclasses=False):
680680
"""Check whether an ORM model is registered.
681681
682682
:param ModelSubclass: the :py:class:`~postgres.orm.Model` subclass to
683683
check for
684+
:param bool include_subsubclasses: whether to also check for subclasses
685+
of :py:class:`ModelSubclass` or just :py:class:`ModelSubclass`
686+
itself
684687
685688
:returns: the :py:attr:`typname` (a string) for which this model is
686689
registered, or a list of strings if it's registered for multiple
@@ -690,7 +693,13 @@ def check_registration(self, ModelSubclass):
690693
:raises: :py:exc:`~postgres.NotRegistered`
691694
692695
"""
693-
keys = [k for k,v in self.model_registry.items() if v is ModelSubclass]
696+
self._validate_model_subclass(ModelSubclass)
697+
698+
if include_subsubclasses:
699+
filt = lambda v: v is ModelSubclass or issubclass(ModelSubclass, v)
700+
else:
701+
filt = lambda v: v is ModelSubclass
702+
keys = [k for k,v in self.model_registry.items() if filt(v)]
694703
if not keys:
695704
raise NotRegistered(ModelSubclass)
696705
if len(keys) == 1:
@@ -700,6 +709,11 @@ def check_registration(self, ModelSubclass):
700709
return keys
701710

702711

712+
def _validate_model_subclass(self, ModelSubclass, ):
713+
if not isclass(ModelSubclass) or not issubclass(ModelSubclass, Model):
714+
raise NotAModel(ModelSubclass)
715+
716+
703717
# Class Factories
704718
# ===============
705719

postgres/orm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ class Model(object):
223223
def __init__(self, record):
224224
if self.db is None:
225225
raise NotBound(self)
226-
self.db.check_registration(self.__class__)
226+
self.db.check_registration(self.__class__, include_subsubclasses=True)
227227
self.__read_only_attributes = record.keys()
228228
self.set_attributes(**record)
229229

tests.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from collections import namedtuple
55
from unittest import TestCase
66

7-
from postgres import Postgres, NotRegistered
7+
from postgres import Postgres, NotAModel, NotRegistered
88
from postgres.cursors import TooFew, TooMany, SimpleDictCursor
99
from postgres.orm import ReadOnly, Model
1010
from psycopg2 import InterfaceError, ProgrammingError
@@ -270,7 +270,27 @@ def assign():
270270

271271
def test_check_register_raises_if_passed_a_model_instance(self):
272272
obj = self.MyModel({'bar': 'baz'})
273-
raises(NotRegistered, self.db.check_registration, obj)
273+
raises(NotAModel, self.db.check_registration, obj)
274+
275+
def test_check_register_doesnt_include_subsubclasses(self):
276+
class Other(self.MyModel): pass
277+
raises(NotRegistered, self.db.check_registration, Other)
278+
279+
def test_dot_dot_dot_unless_you_ask_it_to(self):
280+
class Other(self.MyModel): pass
281+
assert self.db.check_registration(Other, True) == 'foo'
282+
283+
def test_check_register_handles_complex_cases(self):
284+
self.installFlah()
285+
286+
class Second(Model): pass
287+
self.db.run("CREATE TABLE blum (bar text)")
288+
self.db.register_model(Second, 'blum')
289+
assert self.db.check_registration(Second) == 'blum'
290+
291+
class Third(self.MyModel, Second): pass
292+
actual = list(sorted(self.db.check_registration(Third, True)))
293+
assert actual == ['blum', 'flah', 'foo']
274294

275295
def test_a_model_can_be_used_for_a_second_type(self):
276296
self.installFlah()

0 commit comments

Comments
 (0)