Skip to content

Commit 775e316

Browse files
committed
Fix regression with Model instantiation
We need to pass the class, not an instance, to check_registration, but I neglected to address that when moving the check_registration call inside of Model.__init__. This bug was masked by a bug in check_registration itself (variable clobbering in a loop).
1 parent ad54b75 commit 775e316

File tree

3 files changed

+9
-3
lines changed

3 files changed

+9
-3
lines changed

postgres/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -673,8 +673,9 @@ def check_registration(self, ModelSubclass):
673673
674674
"""
675675
key = None
676-
for key, v in self.model_registry.items():
676+
for k, v in self.model_registry.items():
677677
if v is ModelSubclass:
678+
key = k
678679
break
679680
if key is None:
680681
raise NotRegistered(ModelSubclass)

postgres/orm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ class Model(object):
223223
def __init__(self, record):
224224
if self.db is None:
225225
raise NotBound(self)
226-
self.db.check_registration(self)
226+
self.db.check_registration(self.__class__)
227227
self.__read_only_attributes = record.keys()
228228
self.set_attributes(**record)
229229

tests.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,11 @@
44
from collections import namedtuple
55
from unittest import TestCase
66

7-
from postgres import Postgres
7+
from postgres import Postgres, NotRegistered
88
from postgres.cursors import TooFew, TooMany, SimpleDictCursor
99
from postgres.orm import ReadOnly
1010
from psycopg2 import InterfaceError, ProgrammingError
11+
from pytest import raises
1112

1213

1314
DATABASE_URL = os.environ['DATABASE_URL']
@@ -266,6 +267,10 @@ def assign():
266267
one.bar = "blah"
267268
self.assertRaises(ReadOnly, assign)
268269

270+
def test_check_register_raises_if_passed_a_model_instance(self):
271+
obj = self.MyModel({'bar': 'baz'})
272+
raises(NotRegistered, self.db.check_registration, obj)
273+
269274

270275
# cursor_factory
271276
# ==============

0 commit comments

Comments
 (0)