Skip to content

Commit 210401e

Browse files
committed
Support lookup from related collections (#1)
Add support to select_related.
1 parent 72a20d4 commit 210401e

File tree

9 files changed

+168
-12
lines changed

9 files changed

+168
-12
lines changed

.github/workflows/test-python.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ jobs:
9393
ordering
9494
or_lookups
9595
queries.tests.Ticket12807Tests.test_ticket_12807
96+
select_related
9697
sessions_tests
9798
timezones
9899
update

README.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,6 @@ Migrations for 'admin':
114114
- `datetimes()`
115115
- `distinct()`
116116
- `extra()`
117-
- `select_related()`
118117

119118
- `Subquery`, `Exists`, and using a `QuerySet` in `QuerySet.annotate()` aren't
120119
supported.

django_mongodb/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,14 @@
66

77
check_django_compatability()
88

9+
from .datastructures import register_datastructures # noqa: E402
910
from .expressions import register_expressions # noqa: E402
1011
from .fields import register_fields # noqa: E402
1112
from .functions import register_functions # noqa: E402
1213
from .lookups import register_lookups # noqa: E402
1314
from .query import register_nodes # noqa: E402
1415

16+
register_datastructures()
1517
register_expressions()
1618
register_fields()
1719
register_functions()

django_mongodb/compiler.py

Lines changed: 62 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -44,16 +44,33 @@ def results_iter(
4444
"""
4545
columns = self.get_columns()
4646

47+
related_columns = []
4748
if results is None:
4849
# QuerySet.values() or values_list()
4950
try:
5051
results = self.build_query(columns).fetch()
5152
except EmptyResultSet:
5253
results = []
54+
else:
55+
index = len(columns)
56+
while index < self.col_count:
57+
foreign_columns = []
58+
foreign_relation = self.select[index][0].alias
59+
while index < self.col_count and foreign_relation == self.select[index][0].alias:
60+
foreign_columns.append(self.select[index][0])
61+
index += 1
62+
related_columns.append(
63+
(
64+
foreign_relation,
65+
[(column.target.column, column) for column in foreign_columns],
66+
)
67+
)
5368

5469
converters = self.get_converters(columns)
5570
for entity in results:
56-
yield self._make_result(entity, columns, converters, tuple_expected=tuple_expected)
71+
yield self._make_result(
72+
entity, columns, related_columns, converters, tuple_expected=tuple_expected
73+
)
5774

5875
def has_results(self):
5976
return bool(self.get_count(check_exists=True))
@@ -72,13 +89,22 @@ def get_converters(self, expressions):
7289
converters[name] = backend_converters + field_converters
7390
return converters
7491

75-
def _make_result(self, entity, columns, converters, tuple_expected=False):
92+
def _make_result(self, entity, columns, related_columns, converters, tuple_expected=False):
7693
"""
7794
Decode values for the given fields from the database entity.
7895
7996
The entity is assumed to be a dict using field database column
8097
names as keys.
8198
"""
99+
result = self._project_result(entity, columns, converters, tuple_expected)
100+
# Related columns
101+
for relation, columns in related_columns:
102+
result += self._project_result(entity[relation], columns, converters, tuple_expected)
103+
if tuple_expected:
104+
result = tuple(result)
105+
return result
106+
107+
def _project_result(self, entity, columns, converters, tuple_expected=False):
82108
result = []
83109
for name, col in columns:
84110
field = col.field
@@ -90,8 +116,6 @@ def _make_result(self, entity, columns, converters, tuple_expected=False):
90116
for converter in converters.get(name, ()):
91117
value = converter(value, col, self.connection)
92118
result.append(value)
93-
if tuple_expected:
94-
result = tuple(result)
95119
return result
96120

97121
def check_query(self):
@@ -111,9 +135,11 @@ def check_query(self):
111135
if self.query.extra:
112136
raise NotSupportedError("QuerySet.extra() is not supported on MongoDB.")
113137
if self.query.select_related:
114-
raise NotSupportedError("QuerySet.select_related() is not supported on MongoDB.")
138+
pass
139+
# raise NotSupportedError("QuerySet.select_related() is not supported on MongoDB.")
115140
if len([a for a in self.query.alias_map if self.query.alias_refcount[a]]) > 1:
116-
raise NotSupportedError("Queries with multiple tables are not supported on MongoDB.")
141+
pass
142+
# raise NotSupportedError("Queries with multiple tables are not supported on MongoDB.")
117143
if any(
118144
isinstance(a, Aggregate) and not isinstance(a, Count)
119145
for a in self.query.annotations.values()
@@ -145,8 +171,8 @@ def get_count(self, check_exists=False):
145171
def build_query(self, columns=None):
146172
"""Check if the query is supported and prepare a MongoQuery."""
147173
self.check_query()
148-
self.setup_query()
149174
query = self.query_class(self, columns)
175+
query.mongo_lookups = self.get_lookup_clauses()
150176
try:
151177
query.mongo_query = {"$expr": self.query.where.as_mql(self, self.connection)}
152178
except FullResultSet:
@@ -212,8 +238,36 @@ def _get_ordering(self):
212238
field_ordering.append((opts.get_field(name), ascending))
213239
return field_ordering
214240

241+
@property
242+
def collection_name(self):
243+
return self.query.get_meta().db_table
244+
215245
def get_collection(self):
216-
return self.connection.get_collection(self.query.get_meta().db_table)
246+
return self.connection.get_collection(self.collection_name)
247+
248+
def get_lookup_clauses(self):
249+
result = []
250+
for alias in tuple(self.query.alias_map):
251+
if not self.query.alias_refcount[alias] or self.collection_name == alias:
252+
continue
253+
254+
from_clause = self.query.alias_map[alias]
255+
clause_mql = from_clause.as_mql(self, self.connection)
256+
result += clause_mql
257+
258+
"""
259+
for t in self.query.extra_tables:
260+
alias, _ = self.query.table_alias(t)
261+
# Only add the alias if it's not already present (the table_alias()
262+
# call increments the refcount, so an alias refcount of one means
263+
# this is the only reference).
264+
if (
265+
alias not in self.query.alias_map
266+
or self.query.alias_refcount[alias] == 1
267+
):
268+
result.append(", %s" % self.quote_name_unless_alias(alias))
269+
"""
270+
return result
217271

218272

219273
class SQLInsertCompiler(SQLCompiler):

django_mongodb/datastructures.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
from django.db.models.sql.constants import INNER
2+
from django.db.models.sql.datastructures import Join
3+
4+
5+
def join(self, compiler, connection):
6+
lookups_pipeline = []
7+
join_fields = self.join_fields or self.join_cols
8+
lhs_fields = []
9+
rhs_fields = []
10+
for lhs, rhs in join_fields:
11+
if isinstance(lhs, str):
12+
lhs_mql = lhs
13+
rhs_mql = rhs
14+
else:
15+
lhs, rhs = connection.ops.prepare_join_on_clause(
16+
self.parent_alias, lhs, self.table_name, rhs
17+
)
18+
lhs_mql = lhs.as_mql(compiler, connection)
19+
rhs_mql = rhs.as_mql(compiler, connection)
20+
# replace prefix, in lookup stages the reference
21+
# to this column is without the collection name.
22+
rhs_mql = rhs_mql.replace(f"{self.table_name}.", "", 1)
23+
lhs_fields.append(lhs_mql)
24+
rhs_fields.append(rhs_mql)
25+
26+
# temp_table_name = f"{self.table_alias}__array"
27+
parent_template = "parent__field__"
28+
lookups_pipeline = [
29+
{
30+
"$lookup": {
31+
"from": self.table_name,
32+
"let": {
33+
f"{parent_template}{i}": parent_field
34+
for i, parent_field in enumerate(lhs_fields)
35+
},
36+
"pipeline": [
37+
{
38+
"$match": {
39+
"$expr": {
40+
"$and": [
41+
{"$eq": [f"$${parent_template}{i}", field]}
42+
for i, field in enumerate(rhs_fields)
43+
]
44+
}
45+
}
46+
}
47+
],
48+
"as": self.table_alias,
49+
}
50+
},
51+
]
52+
if self.join_type != INNER:
53+
lookups_pipeline.append(
54+
{
55+
"$project": {
56+
self.table_alias: {
57+
"$cond": {
58+
"if": {
59+
"$or": [
60+
{"$eq": [{"$type": "$arrayField"}, "missing"]},
61+
{"$eq": [{"$size": "$arrayField"}, 0]},
62+
]
63+
},
64+
"then": [None],
65+
"else": "$arrayField",
66+
}
67+
}
68+
}
69+
}
70+
)
71+
lookups_pipeline.append({"$unwind": f"${self.table_alias}"})
72+
return lookups_pipeline
73+
74+
75+
def register_datastructures():
76+
Join.as_mql = join

django_mongodb/expressions.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,8 @@ def case(self, compiler, connection):
4343

4444

4545
def col(self, compiler, connection): # noqa: ARG001
46-
return f"${self.target.column}"
46+
prefix = f"{self.alias}." if self.alias != compiler.collection_name else ""
47+
return f"${prefix}{self.target.column}"
4748

4849

4950
def combined_expression(self, compiler, connection):

django_mongodb/features.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,7 @@ def django_test_expected_failures(self):
297297
"ordering.tests.OrderingTests.test_extra_ordering",
298298
"ordering.tests.OrderingTests.test_extra_ordering_quoting",
299299
"ordering.tests.OrderingTests.test_extra_ordering_with_table_name",
300+
"select_related.tests.SelectRelatedTests.test_select_related_with_extra",
300301
},
301302
"Queries with multiple tables are not supported.": {
302303
"annotations.tests.AliasTests.test_alias_default_alias_expression",

django_mongodb/operations.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from django.conf import settings
88
from django.db import DataError
99
from django.db.backends.base.operations import BaseDatabaseOperations
10-
from django.db.models.expressions import Combinable
10+
from django.db.models.expressions import Col, Combinable
1111
from django.utils import timezone
1212
from django.utils.regex_helper import _lazy_re_compile
1313

@@ -159,6 +159,12 @@ def execute_sql_flush(self, tables):
159159
if not options.get("capped", False):
160160
collection.drop()
161161

162+
def prepare_join_on_clause(self, lhs_table, lhs_field, rhs_table, rhs_field):
163+
lhs_expr = Col(lhs_table, lhs_field)
164+
rhs_expr = Col(rhs_table, rhs_field)
165+
166+
return lhs_expr, rhs_expr
167+
162168
def prep_lookup_value(self, value, field, lookup):
163169
"""
164170
Perform type-conversion on `value` before using as a filter parameter.

django_mongodb/query.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,13 @@ def __init__(self, compiler, columns):
4040
self.columns = columns
4141
self._negated = False
4242
self.ordering = []
43+
self.collection_name = self.compiler.collection_name
4344
self.collection = self.compiler.get_collection()
4445
self.mongo_query = getattr(compiler.query, "raw_query", {})
46+
# maybe I have to create a new object or named tuple.
47+
# it will save lookups, some filters (in case of inner) and project to rename field
48+
# don't know if the rename is needed
49+
self.mongo_lookups = None
4550

4651
def __repr__(self):
4752
return f"<MongoQuery: {self.mongo_query!r} ORDER {self.ordering!r}>"
@@ -106,11 +111,22 @@ def get_cursor(self):
106111
# If name != column, then this is an annotatation referencing
107112
# another column.
108113
fields[name] = 1 if name == column else f"${column}"
114+
115+
# add the subquery tables. if fields is defined
116+
related_fields = {}
117+
if fields:
118+
for alias in self.query.alias_map:
119+
if self.query.alias_refcount[alias] > 0 and self.collection_name != alias:
120+
related_fields[alias] = 1
121+
109122
pipeline = []
123+
if self.mongo_lookups:
124+
lookups = self.mongo_lookups
125+
pipeline.extend(lookups)
110126
if self.mongo_query:
111127
pipeline.append({"$match": self.mongo_query})
112128
if fields:
113-
pipeline.append({"$project": fields})
129+
pipeline.append({"$project": {**fields, **related_fields}})
114130
if self.ordering:
115131
pipeline.append({"$sort": dict(self.ordering)})
116132
if self.query.low_mark > 0:

0 commit comments

Comments
 (0)