Skip to content

Commit 705acf4

Browse files
committed
polymorphic accessors now use builtin caching from underlying fields
1 parent bd5faf0 commit 705acf4

File tree

2 files changed

+36
-5
lines changed

2 files changed

+36
-5
lines changed

polymorphic/models.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -197,11 +197,15 @@ def __init__(self, *args, **kwargs):
197197
return
198198
self.__class__.polymorphic_super_sub_accessors_replaced = True
199199

200-
def create_accessor_function_for_model(model, accessor_name):
200+
def create_accessor_function_for_model(model, field):
201201
def accessor_function(self):
202-
objects = getattr(model, "_base_objects", model.objects)
203-
attr = objects.get(pk=self.pk)
204-
return attr
202+
try:
203+
rel_obj = field.get_cached_value(self)
204+
except KeyError:
205+
objects = getattr(model, "_base_objects", model.objects)
206+
rel_obj = objects.get(pk=self.pk)
207+
field.set_cached_value(self, rel_obj)
208+
return rel_obj
205209

206210
return accessor_function
207211

@@ -214,10 +218,14 @@ def accessor_function(self):
214218
type(orig_accessor),
215219
(ReverseOneToOneDescriptor, ForwardManyToOneDescriptor),
216220
):
221+
222+
field = orig_accessor.related \
223+
if isinstance(orig_accessor, ReverseOneToOneDescriptor) else orig_accessor.field
224+
217225
setattr(
218226
self.__class__,
219227
name,
220-
property(create_accessor_function_for_model(model, name)),
228+
property(create_accessor_function_for_model(model, field)),
221229
)
222230

223231
def _get_inheritance_relation_fields_and_models(self):

polymorphic/tests/test_orm.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -985,6 +985,29 @@ def test_parent_link_and_related_name(self):
985985
# test that we can delete the object
986986
t.delete()
987987

988+
def test_polymorphic__accessor_caching(self):
989+
blog_a = BlogA.objects.create(name="blog")
990+
991+
blog_base = BlogBase.objects.non_polymorphic().get(id=blog_a.id)
992+
blog_a = BlogA.objects.get(id=blog_a.id)
993+
994+
# test reverse accessor & check that we get back cached object on repeated access
995+
self.assertEqual(blog_base.bloga, blog_a)
996+
self.assertIs(blog_base.bloga, blog_base.bloga)
997+
cached_blog_a = blog_base.bloga
998+
999+
# test forward accessor & check that we get back cached object on repeated access
1000+
self.assertEqual(blog_a.blogbase_ptr, blog_base)
1001+
self.assertIs(blog_a.blogbase_ptr, blog_a.blogbase_ptr)
1002+
cached_blog_base = blog_a.blogbase_ptr
1003+
1004+
# check that refresh_from_db correctly clears cached related objects
1005+
blog_base.refresh_from_db()
1006+
blog_a.refresh_from_db()
1007+
1008+
self.assertIsNot(cached_blog_a, blog_base.bloga)
1009+
self.assertIsNot(cached_blog_base, blog_a.blogbase_ptr)
1010+
9881011
def test_polymorphic__aggregate(self):
9891012
"""test ModelX___field syntax on aggregate (should work for annotate either)"""
9901013

0 commit comments

Comments
 (0)