Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 46 additions & 2 deletions polymorphic/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,10 @@ def __init__(self, *args, **kwargs):
# to that queryset as well).
self.polymorphic_deferred_loading = (set(), True)

self._polymorphic_select_related = {}
self._polymorphic_prefetch_related = {}
self._polymorphic_custom_queryset = {}

def _clone(self, *args, **kwargs):
# Django's _clone only copies its own variables, so we need to copy ours here
new = super()._clone(*args, **kwargs)
Expand All @@ -120,6 +124,9 @@ def _clone(self, *args, **kwargs):
copy.copy(self.polymorphic_deferred_loading[0]),
self.polymorphic_deferred_loading[1],
)
new._polymorphic_select_related = copy.copy(self._polymorphic_select_related)
new._polymorphic_prefetch_related = copy.copy(self._polymorphic_prefetch_related)
new._polymorphic_custom_queryset = copy.copy(self._polymorphic_custom_queryset)
return new

def as_manager(cls):
Expand Down Expand Up @@ -417,12 +424,30 @@ class self.model, but as a class derived from self.model. We want to re-fetch
# TODO: defer(), only(): support for these would be around here
for real_concrete_class, idlist in idlist_per_model.items():
indices = indexlist_per_model[real_concrete_class]
real_objects = real_concrete_class._base_objects.db_manager(self.db).filter(
if self._polymorphic_custom_queryset.get(real_concrete_class):
real_objects = self._polymorphic_custom_queryset[real_concrete_class]
else:
real_objects = real_concrete_class._base_objects.db_manager(self.db)

real_objects = real_objects.filter(
**{("%s__in" % pk_name): idlist}
)
# copy select related configuration to new qs

# copy select_related() fields from base objects to real objects
real_objects.query.select_related = self.query.select_related

# polymorphic select_related() fields if any
if real_concrete_class in self._polymorphic_select_related:
real_objects = real_objects.select_related(
*self._polymorphic_select_related[real_concrete_class]
)

# polymorphic prefetch related configuration to new qs
if real_concrete_class in self._polymorphic_prefetch_related:
real_objects = real_objects.prefetch_related(
*self._polymorphic_prefetch_related[real_concrete_class]
)

# Copy deferred fields configuration to the new queryset
deferred_loading_fields = []
existing_fields = self.polymorphic_deferred_loading[0]
Expand Down Expand Up @@ -535,3 +560,22 @@ def get_real_instances(self, base_result_objects=None):
return olist
clist = PolymorphicQuerySet._p_list_class(olist)
return clist

def select_polymorphic_related(self, polymorphic_subclass, *fields):
if self.query.select_related is True:
raise ValueError(
"select_polymorphic_related() cannot be used together with select_related=True"
)
clone = self._clone()
clone._polymorphic_select_related[polymorphic_subclass] = fields
return clone

def prefetch_polymorphic_related(self, polymorphic_subclass, *lookups):
clone = self._clone()
clone._polymorphic_prefetch_related[polymorphic_subclass] = lookups
return clone

def custom_queryset(self, polymorphic_subclass, queryset):
clone = self._clone()
clone._polymorphic_custom_queryset[polymorphic_subclass] = queryset
return clone