@@ -18,10 +18,18 @@ class SQLCompiler(compiler.SQLCompiler):
18
18
query_class = MongoQuery
19
19
_group_pipeline = None
20
20
21
+ @staticmethod
22
+ def _random_separtor ():
23
+ import random
24
+ import string
25
+
26
+ size = 6
27
+ chars = string .ascii_uppercase + string .digits
28
+ return "" .join (random .choice (chars ) for _ in range (size )) # noqa: S311
29
+
21
30
def pre_sql_setup (self , with_col_aliases = False ):
22
31
pre_setup = super ().pre_sql_setup (with_col_aliases = with_col_aliases )
23
32
self .annotations = {}
24
- # mongo_having = self.having.copy() if self.having else None
25
33
group = {}
26
34
group_expressions = set ()
27
35
aggregation_idx = 1
@@ -37,13 +45,14 @@ def pre_sql_setup(self, with_col_aliases=False):
37
45
aggregation_idx += 1
38
46
else :
39
47
alias = target
40
- group_expressions |= set (sub_expr .get_group_by_cols ())
41
48
group [alias ] = sub_expr .as_mql (self , self .connection )
42
49
column_target = expr .output_field .__class__ ()
50
+ column_target .db_column = alias
43
51
column_target .set_attributes_from_name (alias )
44
52
replacements [sub_expr ] = Col (self .collection_name , column_target )
45
53
result_expr = expr .replace_expressions (replacements )
46
54
all_replacements .update (replacements )
55
+ group_expressions |= set (expr .get_group_by_cols ())
47
56
self .annotations [target ] = result_expr
48
57
if group :
49
58
order_by = self .get_order_by ()
@@ -58,24 +67,69 @@ def pre_sql_setup(self, with_col_aliases=False):
58
67
if isinstance (self .query .group_by , tuple | list ):
59
68
group_expressions |= set (self .query .group_by )
60
69
70
+ all_strings = "" .join (
71
+ str (col .as_mql (self , self .connection )) for col in group_expressions
72
+ )
73
+
74
+ while True :
75
+ random_string = self ._random_separtor ()
76
+ if random_string not in all_strings :
77
+ break
78
+ SEPARATOR = f"__{ random_string } __"
79
+
80
+ def _ccc (col ):
81
+ if self .collection_name == col .alias :
82
+ return col .target .column
83
+ return f"{ col .alias } { SEPARATOR } { col .target .column } "
84
+
61
85
ids = (
62
86
None
63
87
if not group_expressions
64
88
else {
65
- col . target . column : col .as_mql (self , self .connection )
89
+ _ccc ( col ) : col .as_mql (self , self .connection )
66
90
# expression aren't needed in the group by clouse ()
67
91
for col in group_expressions
68
92
if isinstance (col , Col )
69
93
}
70
94
)
71
- group ["_id" ] = ids
72
- pipeline = [{"$group" : group }]
73
- if ids :
95
+ pipeline = []
96
+ if ids is None :
97
+ group ["_id" ] = None
98
+ pipeline .append ({"$facet" : {"group" : [{"$group" : group }]}})
74
99
pipeline .append (
75
- {"$addFields" : {key : f"$_id.{ value [1 :]} " for key , value in ids .items ()}}
100
+ {
101
+ "$project" : {
102
+ key : {
103
+ "$getField" : {
104
+ "input" : {"$arrayElemAt" : ["$group" , 0 ]},
105
+ "field" : key ,
106
+ }
107
+ }
108
+ for key in group
109
+ }
110
+ }
76
111
)
77
- if "_id" not in ids :
112
+ else :
113
+ group ["_id" ] = ids
114
+ pipeline .append ({"$group" : group })
115
+ sets = {}
116
+ for key in ids :
117
+ value = f"$_id.{ key } "
118
+ if SEPARATOR in key :
119
+ subtable , field = key .split (SEPARATOR )
120
+ if subtable not in sets :
121
+ sets [subtable ] = {}
122
+ sets [subtable ][field ] = value
123
+ else :
124
+ sets [key ] = value
125
+
126
+ pipeline .append (
127
+ # {"$addFields": {key: f"$_id.{value[1:]}" for key, value in ids.items()}}
128
+ {"$addFields" : sets }
129
+ )
130
+ if "_id" not in sets :
78
131
pipeline .append ({"$unset" : "_id" })
132
+
79
133
if self .having :
80
134
pipeline .append (
81
135
{
@@ -224,14 +278,14 @@ def build_query(self, columns=None):
224
278
query = self .query_class (self )
225
279
query .aggregation_stage = self .get_aggregation_pipeline ()
226
280
query .lookup_pipeline = self .get_lookup_pipeline ()
227
- query .project_fields = self .get_project_fields (columns )
281
+ query .order_by (self ._get_ordering ())
282
+ query .project_fields = self .get_project_fields (columns , ordering = query .ordering )
228
283
try :
229
284
query .mongo_query = (
230
285
{"$expr" : self .where .as_mql (self , self .connection )} if self .where else None
231
286
)
232
287
except FullResultSet :
233
288
query .mongo_query = {}
234
- query .order_by (self ._get_ordering ())
235
289
return query
236
290
237
291
def get_columns (self ):
@@ -337,7 +391,7 @@ def _get_aggregate_expressions(self, expr):
337
391
def get_aggregation_pipeline (self ):
338
392
return self ._group_pipeline
339
393
340
- def get_project_fields (self , columns = None ):
394
+ def get_project_fields (self , columns = None , ordering = None ):
341
395
fields = {}
342
396
for name , expr in columns or []:
343
397
try :
@@ -361,6 +415,10 @@ def get_project_fields(self, columns=None):
361
415
if self .query .alias_refcount [alias ] and self .collection_name != alias :
362
416
fields [alias ] = 1
363
417
418
+ for column , _ in ordering or []:
419
+ if column not in fields :
420
+ fields [column ] = 1
421
+
364
422
return fields
365
423
366
424
@@ -461,7 +519,16 @@ def build_query(self, columns=None):
461
519
elide_empty = self .elide_empty ,
462
520
)
463
521
compiler .pre_sql_setup (with_col_aliases = False )
464
- query .sub_query = compiler .build_query ()
522
+ columns = (
523
+ compiler .get_columns ()
524
+ if compiler .query .annotations or not compiler .query .default_cols
525
+ else None
526
+ )
527
+ subquery = compiler .build_query (
528
+ # Avoid $project (columns=None) if unneeded.
529
+ columns
530
+ )
531
+ query .subquery = subquery
465
532
return query
466
533
467
534
def _make_result (self , result , columns = None ):
0 commit comments