Skip to content

Commit 76323ab

Browse files
Alex HillAlexHill
authored andcommitted
Support using L() inside complex expressions in predicate
1 parent 974fd24 commit 76323ab

File tree

4 files changed

+48
-8
lines changed

4 files changed

+48
-8
lines changed

relativity/fields.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
from __future__ import unicode_literals, absolute_import
22

3+
import copy
34
from collections import OrderedDict
45

56
import django
67
from django.db import models, connections
7-
from django.db.models import F, ForeignObject
8+
from django.db.models import F, ForeignObject, Value
89
from django.db.models.fields.related_descriptors import (
910
ReverseManyToOneDescriptor,
1011
ReverseOneToOneDescriptor,
@@ -237,19 +238,28 @@ def get_extra_restriction(self, where_class, alias, related_alias):
237238
where_class=where_class,
238239
)
239240

241+
@classmethod
242+
def _resolve_expression_local_references(cls, expr, obj):
243+
if isinstance(expr, L):
244+
return expr._relativity_resolve_for_instance(obj)
245+
else:
246+
for source_expr in expr.get_source_expressions():
247+
cls._resolve_expression_local_references(source_expr, obj)
248+
return expr
249+
240250
def get_forward_related_filter(self, obj):
241251
"""
242252
Return the filter arguments which select the instances of self.model
243253
that are related to obj.
244254
"""
245255
q = self.field.predicate
246-
q = q() if callable(q) else q
256+
q = q() if callable(q) else copy.deepcopy(q)
247257

248258
# If this is a simple restriction that can be expressed as an AND of
249259
# two basic field lookups, we can return a dictionary of filters...
250260
if q.connector == Q.AND and all(type(c) == tuple for c in q.children):
251261
return {
252-
lookup: getattr(obj, v.name) if isinstance(v, L) else v
262+
lookup: self._resolve_expression_local_references(v, obj)
253263
for lookup, v in q.children
254264
}
255265

@@ -371,6 +381,11 @@ def relationship_related_query_name(self):
371381

372382

373383
class L(F):
384+
def _relativity_resolve_for_instance(self, obj):
385+
val = getattr(obj, self.name)
386+
self._relativity_resolved_value = Value(val)
387+
return val
388+
374389
def resolve_expression(
375390
self,
376391
query=None,
@@ -380,7 +395,10 @@ def resolve_expression(
380395
for_save=False,
381396
simple_col=False,
382397
):
383-
# noinspection PyProtectedMember
384-
return super(L, self).resolve_expression(
385-
query._relationship_field_query, allow_joins, reuse, summarize, for_save
386-
)
398+
if hasattr(self, "_relativity_resolved_value"):
399+
return self._relativity_resolved_value
400+
else:
401+
# noinspection PyProtectedMember
402+
return super(L, self).resolve_expression(
403+
query._relationship_field_query, allow_joins, reuse, summarize, for_save
404+
)

tests/models.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
from __future__ import absolute_import, unicode_literals
22

33
from django.db import models
4-
from django.db.models import Lookup, Q
4+
from django.db.models import Lookup, Q, Value
55
from django.db.models.fields import Field
6+
from django.db.models.functions import Concat
67
from mptt.fields import TreeForeignKey
78
from mptt.models import MPTTModel
89
from six import python_2_unicode_compatible
@@ -186,3 +187,17 @@ class SavedFilter(models.Model):
186187
user = models.ForeignKey(User, on_delete=models.CASCADE)
187188
search_regex = models.TextField()
188189
chemicals = Relationship(Chemical, Q(formula__regex=L("search_regex")))
190+
191+
192+
class UserGenerator(models.Model):
193+
194+
user = Relationship(
195+
User,
196+
Q(username=Concat(Value("generated_for_"), L("id"))),
197+
multiple=False,
198+
reverse_multiple=False,
199+
)
200+
201+
def save(self, *args, **kwargs):
202+
super(UserGenerator, self).save(*args, **kwargs)
203+
User.objects.create(username="generated_for_%d" % self.id)

tests/settings.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,5 @@
1717
TEST_NON_SERIALIZED_APPS = ["tests"]
1818

1919
DEBUG = True
20+
21+
DEFAULT_AUTO_FIELD = "django.db.models.AutoField"

tests/tests.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
SavedFilter,
1818
User,
1919
LinkedNode,
20+
UserGenerator,
2021
)
2122

2223

@@ -354,3 +355,7 @@ def test_single_reverse(self):
354355
)
355356
self.assertEqual(node_3.prev, node_2)
356357
self.assertEqual(node_1.next, node_2)
358+
359+
def test_complex_expression(self):
360+
ug = UserGenerator.objects.create()
361+
self.assertEqual(ug.user, User.objects.get(username="generated_for_%d" % ug.id))

0 commit comments

Comments
 (0)