Skip to content

Commit e222400

Browse files
committed
Support lookup from related collections (#1)
Add support to select_related.
1 parent 9019a33 commit e222400

File tree

9 files changed

+168
-12
lines changed

9 files changed

+168
-12
lines changed

.github/workflows/test-python.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ jobs:
9191
model_fields
9292
or_lookups
9393
queries.tests.Ticket12807Tests.test_ticket_12807
94+
select_related
9495
sessions_tests
9596
timezones
9697
update

README.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,6 @@ Migrations for 'admin':
114114
- `datetimes()`
115115
- `distinct()`
116116
- `extra()`
117-
- `select_related()`
118117

119118
- Queries with joins aren't supported.
120119

django_mongodb/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,14 @@
66

77
check_django_compatability()
88

9+
from .datastructures import register_structures # noqa: E402
910
from .expressions import register_expressions # noqa: E402
1011
from .fields import register_fields # noqa: E402
1112
from .functions import register_functions # noqa: E402
1213
from .lookups import register_lookups # noqa: E402
1314
from .query import register_nodes # noqa: E402
1415

16+
register_structures()
1517
register_expressions()
1618
register_fields()
1719
register_functions()

django_mongodb/compiler.py

Lines changed: 62 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -44,16 +44,33 @@ def results_iter(
4444
"""
4545
columns = self.get_columns()
4646

47+
related_columns = []
4748
if results is None:
4849
# QuerySet.values() or values_list()
4950
try:
5051
results = self.build_query(columns).fetch()
5152
except EmptyResultSet:
5253
results = []
54+
else:
55+
index = len(columns)
56+
while index < self.col_count:
57+
foreign_columns = []
58+
foreign_relation = self.select[index][0].alias
59+
while index < self.col_count and foreign_relation == self.select[index][0].alias:
60+
foreign_columns.append(self.select[index][0])
61+
index += 1
62+
related_columns.append(
63+
(
64+
foreign_relation,
65+
[(column.target.column, column) for column in foreign_columns],
66+
)
67+
)
5368

5469
converters = self.get_converters(columns)
5570
for entity in results:
56-
yield self._make_result(entity, columns, converters, tuple_expected=tuple_expected)
71+
yield self._make_result(
72+
entity, columns, related_columns, converters, tuple_expected=tuple_expected
73+
)
5774

5875
def has_results(self):
5976
return bool(self.get_count(check_exists=True))
@@ -72,13 +89,22 @@ def get_converters(self, expressions):
7289
converters[name] = backend_converters + field_converters
7390
return converters
7491

75-
def _make_result(self, entity, columns, converters, tuple_expected=False):
92+
def _make_result(self, entity, columns, related_columns, converters, tuple_expected=False):
7693
"""
7794
Decode values for the given fields from the database entity.
7895
7996
The entity is assumed to be a dict using field database column
8097
names as keys.
8198
"""
99+
result = self._project_result(entity, columns, converters, tuple_expected)
100+
# Related columns
101+
for relation, columns in related_columns:
102+
result += self._project_result(entity[relation], columns, converters, tuple_expected)
103+
if tuple_expected:
104+
result = tuple(result)
105+
return result
106+
107+
def _project_result(self, entity, columns, converters, tuple_expected=False):
82108
result = []
83109
for name, col in columns:
84110
field = col.field
@@ -90,8 +116,6 @@ def _make_result(self, entity, columns, converters, tuple_expected=False):
90116
for converter in converters.get(name, ()):
91117
value = converter(value, col, self.connection)
92118
result.append(value)
93-
if tuple_expected:
94-
result = tuple(result)
95119
return result
96120

97121
def check_query(self):
@@ -111,9 +135,11 @@ def check_query(self):
111135
if self.query.extra:
112136
raise NotSupportedError("QuerySet.extra() is not supported on MongoDB.")
113137
if self.query.select_related:
114-
raise NotSupportedError("QuerySet.select_related() is not supported on MongoDB.")
138+
pass
139+
# raise NotSupportedError("QuerySet.select_related() is not supported on MongoDB.")
115140
if len([a for a in self.query.alias_map if self.query.alias_refcount[a]]) > 1:
116-
raise NotSupportedError("Queries with multiple tables are not supported on MongoDB.")
141+
pass
142+
# raise NotSupportedError("Queries with multiple tables are not supported on MongoDB.")
117143
if any(
118144
isinstance(a, Aggregate) and not isinstance(a, Count)
119145
for a in self.query.annotations.values()
@@ -135,8 +161,8 @@ def get_count(self, check_exists=False):
135161
def build_query(self, columns=None):
136162
"""Check if the query is supported and prepare a MongoQuery."""
137163
self.check_query()
138-
self.setup_query()
139164
query = self.query_class(self, columns)
165+
query.mongo_lookups = self.get_lookup_clauses()
140166
try:
141167
query.mongo_query = {"$expr": self.query.where.as_mql(self, self.connection)}
142168
except FullResultSet:
@@ -202,8 +228,36 @@ def _get_ordering(self):
202228
field_ordering.append((opts.get_field(name), ascending))
203229
return field_ordering
204230

231+
@property
232+
def collection_name(self):
233+
return self.query.get_meta().db_table
234+
205235
def get_collection(self):
206-
return self.connection.get_collection(self.query.get_meta().db_table)
236+
return self.connection.get_collection(self.collection_name)
237+
238+
def get_lookup_clauses(self):
239+
result = []
240+
for alias in tuple(self.query.alias_map):
241+
if not self.query.alias_refcount[alias] or self.collection_name == alias:
242+
continue
243+
244+
from_clause = self.query.alias_map[alias]
245+
clause_mql = from_clause.as_mql(self, self.connection)
246+
result += clause_mql
247+
248+
"""
249+
for t in self.query.extra_tables:
250+
alias, _ = self.query.table_alias(t)
251+
# Only add the alias if it's not already present (the table_alias()
252+
# call increments the refcount, so an alias refcount of one means
253+
# this is the only reference).
254+
if (
255+
alias not in self.query.alias_map
256+
or self.query.alias_refcount[alias] == 1
257+
):
258+
result.append(", %s" % self.quote_name_unless_alias(alias))
259+
"""
260+
return result
207261

208262

209263
class SQLInsertCompiler(SQLCompiler):

django_mongodb/datastructures.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
from django.db.models.sql.constants import INNER
2+
from django.db.models.sql.datastructures import Join
3+
4+
5+
def join(self, compiler, connection):
6+
lookups_pipeline = []
7+
join_fields = self.join_fields or self.join_cols
8+
lhs_fields = []
9+
rhs_fields = []
10+
for lhs, rhs in join_fields:
11+
if isinstance(lhs, str):
12+
lhs_mql = lhs
13+
rhs_mql = rhs
14+
else:
15+
lhs, rhs = connection.ops.prepare_join_on_clause(
16+
self.parent_alias, lhs, self.table_name, rhs
17+
)
18+
lhs_mql = lhs.as_mql(compiler, connection)
19+
rhs_mql = rhs.as_mql(compiler, connection)
20+
# replace prefix, in lookup stages the reference
21+
# to this column is without the collection name.
22+
rhs_mql = rhs_mql.replace(f"{self.table_name}.", "", 1)
23+
lhs_fields.append(lhs_mql)
24+
rhs_fields.append(rhs_mql)
25+
26+
# temp_table_name = f"{self.table_alias}__array"
27+
parent_template = "parent__field__"
28+
lookups_pipeline = [
29+
{
30+
"$lookup": {
31+
"from": self.table_name,
32+
"let": {
33+
f"{parent_template}{i}": parent_field
34+
for i, parent_field in enumerate(lhs_fields)
35+
},
36+
"pipeline": [
37+
{
38+
"$match": {
39+
"$expr": {
40+
"$and": [
41+
{"$eq": [f"$${parent_template}{i}", field]}
42+
for i, field in enumerate(rhs_fields)
43+
]
44+
}
45+
}
46+
}
47+
],
48+
"as": self.table_alias,
49+
}
50+
},
51+
]
52+
if self.join_type != INNER:
53+
lookups_pipeline.append(
54+
{
55+
"$project": {
56+
self.table_alias: {
57+
"$cond": {
58+
"if": {
59+
"$or": [
60+
{"$eq": [{"$type": "$arrayField"}, "missing"]},
61+
{"$eq": [{"$size": "$arrayField"}, 0]},
62+
]
63+
},
64+
"then": [None],
65+
"else": "$arrayField",
66+
}
67+
}
68+
}
69+
}
70+
)
71+
lookups_pipeline.append({"$unwind": f"${self.table_alias}"})
72+
return lookups_pipeline
73+
74+
75+
def register_structures():
76+
Join.as_mql = join

django_mongodb/expressions.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,8 @@ def case(self, compiler, connection):
4040

4141

4242
def col(self, compiler, connection): # noqa: ARG001
43-
return f"${self.target.column}"
43+
prefix = f"{self.alias}." if self.alias != compiler.collection_name else ""
44+
return f"${prefix}{self.target.column}"
4445

4546

4647
def combined_expression(self, compiler, connection):

django_mongodb/features.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,7 @@ def django_test_expected_failures(self):
288288
"defer.tests.DeferTests.test_defer_extra",
289289
"lookup.tests.LookupTests.test_values",
290290
"lookup.tests.LookupTests.test_values_list",
291+
"select_related.tests.SelectRelatedTests.test_select_related_with_extra",
291292
},
292293
"Queries with multiple tables are not supported.": {
293294
"annotations.tests.AliasTests.test_alias_default_alias_expression",

django_mongodb/operations.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from django.conf import settings
88
from django.db import DataError
99
from django.db.backends.base.operations import BaseDatabaseOperations
10-
from django.db.models.expressions import Combinable
10+
from django.db.models.expressions import Col, Combinable
1111
from django.utils import timezone
1212
from django.utils.regex_helper import _lazy_re_compile
1313

@@ -154,6 +154,12 @@ def execute_sql_flush(self, tables):
154154
if not options.get("capped", False):
155155
collection.drop()
156156

157+
def prepare_join_on_clause(self, lhs_table, lhs_field, rhs_table, rhs_field):
158+
lhs_expr = Col(lhs_table, lhs_field)
159+
rhs_expr = Col(rhs_table, rhs_field)
160+
161+
return lhs_expr, rhs_expr
162+
157163
def prep_lookup_value(self, value, field, lookup):
158164
"""
159165
Perform type-conversion on `value` before using as a filter parameter.

django_mongodb/query.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,13 @@ def __init__(self, compiler, columns):
4040
self.columns = columns
4141
self._negated = False
4242
self.ordering = []
43+
self.collection_name = self.compiler.collection_name
4344
self.collection = self.compiler.get_collection()
4445
self.mongo_query = getattr(compiler.query, "raw_query", {})
46+
# maybe I have to create a new object or named tuple.
47+
# it will save lookups, some filters (in case of inner) and project to rename field
48+
# don't know if the rename is needed
49+
self.mongo_lookups = None
4550

4651
def __repr__(self):
4752
return f"<MongoQuery: {self.mongo_query!r} ORDER {self.ordering!r}>"
@@ -102,11 +107,22 @@ def get_cursor(self):
102107
# If name != column, then this is an annotatation referencing
103108
# another column.
104109
fields[name] = 1 if name == column else f"${column}"
110+
111+
# add the subquery tables. if fields is defined
112+
related_fields = {}
113+
if fields:
114+
for alias in self.query.alias_map:
115+
if self.query.alias_refcount[alias] > 0 and self.collection_name != alias:
116+
related_fields[alias] = 1
117+
105118
pipeline = []
119+
if self.mongo_lookups:
120+
lookups = self.mongo_lookups
121+
pipeline.extend(lookups)
106122
if self.mongo_query:
107123
pipeline.append({"$match": self.mongo_query})
108124
if fields:
109-
pipeline.append({"$project": fields})
125+
pipeline.append({"$project": {**fields, **related_fields}})
110126
if self.ordering:
111127
pipeline.append({"$sort": dict(self.ordering)})
112128
if self.query.low_mark > 0:

0 commit comments

Comments
 (0)