Skip to content

Commit c1b2f3d

Browse files
committed
Support lookup from related collections (#1)
Add support to select_related.
1 parent 714bb94 commit c1b2f3d

File tree

6 files changed

+104
-3
lines changed

6 files changed

+104
-3
lines changed

django_mongodb/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,14 @@
66

77
check_django_compatability()
88

9+
from .datastructures import register_datastructures # noqa: E402
910
from .expressions import register_expressions # noqa: E402
1011
from .fields import register_fields # noqa: E402
1112
from .functions import register_functions # noqa: E402
1213
from .lookups import register_lookups # noqa: E402
1314
from .query import register_nodes # noqa: E402
1415

16+
register_datastructures()
1517
register_expressions()
1618
register_fields()
1719
register_functions()

django_mongodb/compiler.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,15 @@ def _make_result(self, entity, columns):
7373
The entity is assumed to be a dict using field database column
7474
names as keys.
7575
"""
76+
result = self._project_result(entity, columns, converters, tuple_expected)
77+
# Related columns
78+
for relation, columns in related_columns:
79+
result += self._project_result(entity[relation], columns, converters, tuple_expected)
80+
if tuple_expected:
81+
result = tuple(result)
82+
return result
83+
84+
def _project_result(self, entity, columns, converters, tuple_expected=False):
7685
result = []
7786
for name, col in columns:
7887
column_alias = getattr(col, "alias", None)

django_mongodb/datastructures.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
from django.db.models.sql.constants import INNER
2+
from django.db.models.sql.datastructures import Join
3+
4+
5+
def join(self, compiler, connection):
6+
lookups_pipeline = []
7+
join_fields = self.join_fields or self.join_cols
8+
lhs_fields = []
9+
rhs_fields = []
10+
for lhs, rhs in join_fields:
11+
if isinstance(lhs, str):
12+
lhs_mql = lhs
13+
rhs_mql = rhs
14+
else:
15+
lhs, rhs = connection.ops.prepare_join_on_clause(
16+
self.parent_alias, lhs, self.table_name, rhs
17+
)
18+
lhs_mql = lhs.as_mql(compiler, connection)
19+
rhs_mql = rhs.as_mql(compiler, connection)
20+
# replace prefix, in lookup stages the reference
21+
# to this column is without the collection name.
22+
rhs_mql = rhs_mql.replace(f"{self.table_name}.", "", 1)
23+
lhs_fields.append(lhs_mql)
24+
rhs_fields.append(rhs_mql)
25+
26+
# temp_table_name = f"{self.table_alias}__array"
27+
parent_template = "parent__field__"
28+
lookups_pipeline = [
29+
{
30+
"$lookup": {
31+
"from": self.table_name,
32+
"let": {
33+
f"{parent_template}{i}": parent_field
34+
for i, parent_field in enumerate(lhs_fields)
35+
},
36+
"pipeline": [
37+
{
38+
"$match": {
39+
"$expr": {
40+
"$and": [
41+
{"$eq": [f"$${parent_template}{i}", field]}
42+
for i, field in enumerate(rhs_fields)
43+
]
44+
}
45+
}
46+
}
47+
],
48+
"as": self.table_alias,
49+
}
50+
},
51+
]
52+
if self.join_type != INNER:
53+
lookups_pipeline.append(
54+
{
55+
"$project": {
56+
self.table_alias: {
57+
"$cond": {
58+
"if": {
59+
"$or": [
60+
{"$eq": [{"$type": "$arrayField"}, "missing"]},
61+
{"$eq": [{"$size": "$arrayField"}, 0]},
62+
]
63+
},
64+
"then": [None],
65+
"else": "$arrayField",
66+
}
67+
}
68+
}
69+
}
70+
)
71+
lookups_pipeline.append({"$unwind": f"${self.table_alias}"})
72+
return lookups_pipeline
73+
74+
75+
def register_datastructures():
76+
Join.as_mql = join

django_mongodb/expressions.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,6 @@ def case(self, compiler, connection):
4444

4545

4646
def col(self, compiler, connection): # noqa: ARG001
47-
# Add the column's collection's alias for columns in joined collections.
4847
prefix = f"{self.alias}." if self.alias != compiler.collection_name else ""
4948
return f"${prefix}{self.target.column}"
5049

django_mongodb/operations.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from django.conf import settings
99
from django.db import DataError
1010
from django.db.backends.base.operations import BaseDatabaseOperations
11-
from django.db.models.expressions import Combinable
11+
from django.db.models.expressions import Col, Combinable
1212
from django.utils import timezone
1313
from django.utils.regex_helper import _lazy_re_compile
1414

@@ -176,6 +176,12 @@ def execute_sql_flush(self, tables):
176176
if not options.get("capped", False):
177177
collection.drop()
178178

179+
def prepare_join_on_clause(self, lhs_table, lhs_field, rhs_table, rhs_field):
180+
lhs_expr = Col(lhs_table, lhs_field)
181+
rhs_expr = Col(rhs_table, rhs_field)
182+
183+
return lhs_expr, rhs_expr
184+
179185
def prep_lookup_value(self, value, field, lookup):
180186
"""
181187
Perform type-conversion on `value` before using as a filter parameter.

django_mongodb/query.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ def __init__(self, compiler, columns):
4545
self.columns = columns
4646
self._negated = False
4747
self.ordering = []
48+
self.collection_name = self.compiler.collection_name
4849
self.collection = self.compiler.get_collection()
4950
self.collection_name = self.compiler.collection_name
5051
self.mongo_query = getattr(compiler.query, "raw_query", {})
@@ -123,10 +124,18 @@ def get_cursor(self):
123124
pipeline = []
124125
if self.lookup_pipeline:
125126
pipeline.extend(self.lookup_pipeline)
127+
128+
# add the subquery tables. if fields is defined
129+
related_fields = {}
130+
if fields:
131+
for alias in self.query.alias_map:
132+
if self.query.alias_refcount[alias] > 0 and self.collection_name != alias:
133+
related_fields[alias] = 1
134+
126135
if self.mongo_query:
127136
pipeline.append({"$match": self.mongo_query})
128137
if fields:
129-
pipeline.append({"$project": fields})
138+
pipeline.append({"$project": {**fields, **related_fields}})
130139
if self.ordering:
131140
pipeline.append({"$sort": dict(self.ordering)})
132141
if self.query.low_mark > 0:

0 commit comments

Comments
 (0)