Skip to content

Commit 2fae8d3

Browse files
committed
Enable models to be registered for multiple types
This addresses the underlying concern on #22, though in a different way than is suggested there. The common case is registering for a single type, and that is what MyModel.typname is intended for. If you want to get fancy then use register_composite with a typname kwarg.
1 parent 775e316 commit 2fae8d3

File tree

2 files changed

+75
-25
lines changed

2 files changed

+75
-25
lines changed

postgres/__init__.py

Lines changed: 39 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -599,11 +599,17 @@ def get_connection(self):
599599
return ConnectionContextManager(self.pool)
600600

601601

602-
def register_model(self, ModelSubclass):
602+
def register_model(self, ModelSubclass, typname=None):
603603
"""Register an ORM model.
604604
605605
:param ModelSubclass: the :py:class:`~postgres.orm.Model` subclass to
606606
register with this :py:class:`~postgres.Postgres` instance
607+
608+
:param typname: a string indicating the Postgres type to register this
609+
model for (``typname``, without an "e," is the name of the relevant
610+
column in the underlying ``pg_type`` table). If :py:class:`None`,
611+
we'll look for :py:attr:`ModelSubclass.typname`.
612+
607613
:raises: :py:exc:`~postgres.NotAModel`,
608614
:py:exc:`~postgres.NoTypeSpecified`,
609615
:py:exc:`~postgres.NoSuchType`,
@@ -618,28 +624,30 @@ def register_model(self, ModelSubclass):
618624
if not issubclass(ModelSubclass, Model):
619625
raise NotAModel(ModelSubclass)
620626

621-
if getattr(ModelSubclass, 'typname', None) is None:
622-
raise NoTypeSpecified(ModelSubclass)
627+
if typname is None:
628+
typname = getattr(ModelSubclass, 'typname', None)
629+
if typname is None:
630+
raise NoTypeSpecified(ModelSubclass)
623631

624632
n = self.one( "SELECT count(*) FROM pg_type WHERE typname=%s"
625-
, (ModelSubclass.typname,)
633+
, (typname,)
626634
)
627635
if n < 1:
628636
# Could be more than one since we don't constrain by typnamespace.
629637
# XXX What happens then?
630-
raise NoSuchType(ModelSubclass.typname)
638+
raise NoSuchType(typname)
631639

632-
if ModelSubclass.typname in self.model_registry:
633-
existing_model = self.model_registry[ModelSubclass.typname]
634-
raise AlreadyRegistered(existing_model, ModelSubclass.typname)
640+
if typname in self.model_registry:
641+
existing_model = self.model_registry[typname]
642+
raise AlreadyRegistered(existing_model, typname)
635643

636-
self.model_registry[ModelSubclass.typname] = ModelSubclass
644+
self.model_registry[typname] = ModelSubclass
637645
ModelSubclass.db = self
638646

639-
# register a composite (but don't use RealDictCursor, not sure why)
647+
# register a composite
640648
with self.get_connection() as conn:
641649
cursor = conn.cursor()
642-
name = ModelSubclass.typname
650+
name = typname
643651
if sys.version_info[0] < 3:
644652
name = name.encode('UTF-8')
645653
register_composite( name
@@ -656,30 +664,40 @@ def unregister_model(self, ModelSubclass):
656664
unregister
657665
:raises: :py:exc:`~postgres.NotRegistered`
658666
667+
If :py:class:`ModelSubclass` is registered for multiple types, it is
668+
unregistered for all of them.
669+
659670
"""
660-
key = self.check_registration(ModelSubclass)
661-
del self.model_registry[key]
671+
keys = self.check_registration(ModelSubclass)
672+
if not isinstance(keys, list):
673+
# Wrap single string in a list. Flip-side of XXX just below.
674+
keys = [keys]
675+
for key in keys:
676+
del self.model_registry[key]
662677

663678

664679
def check_registration(self, ModelSubclass):
665680
"""Check whether an ORM model is registered.
666681
667682
:param ModelSubclass: the :py:class:`~postgres.orm.Model` subclass to
668683
check for
684+
669685
:returns: the :py:attr:`typname` (a string) for which this model is
670-
registered
686+
registered, or a list of strings if it's registered for multiple
687+
types
688+
671689
:rettype: string
672690
:raises: :py:exc:`~postgres.NotRegistered`
673691
674692
"""
675-
key = None
676-
for k, v in self.model_registry.items():
677-
if v is ModelSubclass:
678-
key = k
679-
break
680-
if key is None:
693+
keys = [k for k,v in self.model_registry.items() if v is ModelSubclass]
694+
if not keys:
681695
raise NotRegistered(ModelSubclass)
682-
return key
696+
if len(keys) == 1:
697+
# Dereference a single-item list, for backwards-compatibility.
698+
# XXX If/when we go to 3.0, lose this cruft (always return list).
699+
keys = keys[0]
700+
return keys
683701

684702

685703
# Class Factories

tests.py

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from postgres import Postgres, NotRegistered
88
from postgres.cursors import TooFew, TooMany, SimpleDictCursor
9-
from postgres.orm import ReadOnly
9+
from postgres.orm import ReadOnly, Model
1010
from psycopg2 import InterfaceError, ProgrammingError
1111
from pytest import raises
1212

@@ -223,14 +223,11 @@ def test_get_connection_gets_a_connection(self):
223223

224224
class TestORM(WithData):
225225

226-
from postgres.orm import Model
227-
228226
class MyModel(Model):
229227

230228
typname = "foo"
231229

232230
def __init__(self, record):
233-
from postgres.orm import Model
234231
Model.__init__(self, record)
235232
self.bar_from_init = record['bar']
236233

@@ -247,6 +244,10 @@ def setUp(self):
247244
def tearDown(self):
248245
self.db.model_registry = {}
249246

247+
def installFlah(self):
248+
self.db.run("CREATE TABLE flah (bar text)")
249+
self.db.register_model(self.MyModel, 'flah')
250+
250251
def test_orm_basically_works(self):
251252
one = self.db.one("SELECT foo.*::foo FROM foo WHERE bar='baz'")
252253
assert one.__class__ == self.MyModel
@@ -271,6 +272,37 @@ def test_check_register_raises_if_passed_a_model_instance(self):
271272
obj = self.MyModel({'bar': 'baz'})
272273
raises(NotRegistered, self.db.check_registration, obj)
273274

275+
def test_a_model_can_be_used_for_a_second_type(self):
276+
self.installFlah()
277+
self.db.run("INSERT INTO flah VALUES ('double')")
278+
self.db.run("INSERT INTO flah VALUES ('trouble')")
279+
flah = self.db.one("SELECT flah.*::flah FROM flah WHERE bar='double'")
280+
assert flah.bar == "double"
281+
282+
def test_check_register_returns_string_for_single(self):
283+
assert self.db.check_registration(self.MyModel) == 'foo'
284+
285+
def test_check_register_returns_list_for_multiple(self):
286+
self.installFlah()
287+
actual = list(sorted(self.db.check_registration(self.MyModel)))
288+
assert actual == ['flah', 'foo']
289+
290+
def test_unregister_unregisters_one(self):
291+
self.db.unregister_model(self.MyModel)
292+
assert self.db.model_registry == {}
293+
294+
def test_unregister_leaves_other(self):
295+
self.db.run("CREATE TABLE flum (bar text)")
296+
class OtherModel(Model): pass
297+
self.db.register_model(OtherModel, 'flum')
298+
self.db.unregister_model(self.MyModel)
299+
assert self.db.model_registry == {'flum': OtherModel}
300+
301+
def test_unregister_unregisters_multiple(self):
302+
self.installFlah()
303+
self.db.unregister_model(self.MyModel)
304+
assert self.db.model_registry == {}
305+
274306

275307
# cursor_factory
276308
# ==============

0 commit comments

Comments
 (0)