@@ -19,10 +19,18 @@ class SQLCompiler(compiler.SQLCompiler):
19
19
query_class = MongoQuery
20
20
_group_pipeline = None
21
21
22
+ @staticmethod
23
+ def _random_separtor ():
24
+ import random
25
+ import string
26
+
27
+ size = 6
28
+ chars = string .ascii_uppercase + string .digits
29
+ return "" .join (random .choice (chars ) for _ in range (size )) # noqa: S311
30
+
22
31
def pre_sql_setup (self , with_col_aliases = False ):
23
32
pre_setup = super ().pre_sql_setup (with_col_aliases = with_col_aliases )
24
33
self .annotations = {}
25
- # mongo_having = self.having.copy() if self.having else None
26
34
group = {}
27
35
group_expressions = set ()
28
36
aggregation_idx = 1
@@ -38,13 +46,14 @@ def pre_sql_setup(self, with_col_aliases=False):
38
46
aggregation_idx += 1
39
47
else :
40
48
alias = target
41
- group_expressions |= set (sub_expr .get_group_by_cols ())
42
49
group [alias ] = sub_expr .as_mql (self , self .connection )
43
50
column_target = expr .output_field .__class__ ()
51
+ column_target .db_column = alias
44
52
column_target .set_attributes_from_name (alias )
45
53
replacements [sub_expr ] = Col (self .collection_name , column_target )
46
54
result_expr = expr .replace_expressions (replacements )
47
55
all_replacements .update (replacements )
56
+ group_expressions |= set (expr .get_group_by_cols ())
48
57
self .annotations [target ] = result_expr
49
58
if group :
50
59
order_by = self .get_order_by ()
@@ -59,24 +68,69 @@ def pre_sql_setup(self, with_col_aliases=False):
59
68
if isinstance (self .query .group_by , tuple | list ):
60
69
group_expressions |= set (self .query .group_by )
61
70
71
+ all_strings = "" .join (
72
+ str (col .as_mql (self , self .connection )) for col in group_expressions
73
+ )
74
+
75
+ while True :
76
+ random_string = self ._random_separtor ()
77
+ if random_string not in all_strings :
78
+ break
79
+ SEPARATOR = f"__{ random_string } __"
80
+
81
+ def _ccc (col ):
82
+ if self .collection_name == col .alias :
83
+ return col .target .column
84
+ return f"{ col .alias } { SEPARATOR } { col .target .column } "
85
+
62
86
ids = (
63
87
None
64
88
if not group_expressions
65
89
else {
66
- col . target . column : col .as_mql (self , self .connection )
90
+ _ccc ( col ) : col .as_mql (self , self .connection )
67
91
# expression aren't needed in the group by clouse ()
68
92
for col in group_expressions
69
93
if isinstance (col , Col )
70
94
}
71
95
)
72
- group ["_id" ] = ids
73
- pipeline = [{"$group" : group }]
74
- if ids :
96
+ pipeline = []
97
+ if ids is None :
98
+ group ["_id" ] = None
99
+ pipeline .append ({"$facet" : {"group" : [{"$group" : group }]}})
75
100
pipeline .append (
76
- {"$addFields" : {key : f"$_id.{ value [1 :]} " for key , value in ids .items ()}}
101
+ {
102
+ "$project" : {
103
+ key : {
104
+ "$getField" : {
105
+ "input" : {"$arrayElemAt" : ["$group" , 0 ]},
106
+ "field" : key ,
107
+ }
108
+ }
109
+ for key in group
110
+ }
111
+ }
77
112
)
78
- if "_id" not in ids :
113
+ else :
114
+ group ["_id" ] = ids
115
+ pipeline .append ({"$group" : group })
116
+ sets = {}
117
+ for key in ids :
118
+ value = f"$_id.{ key } "
119
+ if SEPARATOR in key :
120
+ subtable , field = key .split (SEPARATOR )
121
+ if subtable not in sets :
122
+ sets [subtable ] = {}
123
+ sets [subtable ][field ] = value
124
+ else :
125
+ sets [key ] = value
126
+
127
+ pipeline .append (
128
+ # {"$addFields": {key: f"$_id.{value[1:]}" for key, value in ids.items()}}
129
+ {"$addFields" : sets }
130
+ )
131
+ if "_id" not in sets :
79
132
pipeline .append ({"$unset" : "_id" })
133
+
80
134
if self .having :
81
135
pipeline .append (
82
136
{
@@ -250,14 +304,14 @@ def build_query(self, columns=None):
250
304
query = self .query_class (self )
251
305
query .aggregation_stage = self .get_aggregation_pipeline ()
252
306
query .lookup_pipeline = self .get_lookup_pipeline ()
253
- query .project_fields = self .get_project_fields (columns )
307
+ query .order_by (self ._get_ordering ())
308
+ query .project_fields = self .get_project_fields (columns , ordering = query .ordering )
254
309
try :
255
310
query .mongo_query = (
256
311
{"$expr" : self .where .as_mql (self , self .connection )} if self .where else None
257
312
)
258
313
except FullResultSet :
259
314
query .mongo_query = {}
260
- query .order_by (self ._get_ordering ())
261
315
return query
262
316
263
317
def get_columns (self ):
@@ -369,7 +423,7 @@ def _get_aggregate_expressions(self, expr):
369
423
def get_aggregation_pipeline (self ):
370
424
return self ._group_pipeline
371
425
372
- def get_project_fields (self , columns = None ):
426
+ def get_project_fields (self , columns = None , ordering = None ):
373
427
fields = {}
374
428
for name , expr in columns or []:
375
429
try :
@@ -393,6 +447,10 @@ def get_project_fields(self, columns=None):
393
447
if self .query .alias_refcount [alias ] and self .collection_name != alias :
394
448
fields [alias ] = 1
395
449
450
+ for column , _ in ordering or []:
451
+ if column not in fields :
452
+ fields [column ] = 1
453
+
396
454
return fields
397
455
398
456
@@ -513,7 +571,16 @@ def build_query(self, columns=None):
513
571
elide_empty = self .elide_empty ,
514
572
)
515
573
compiler .pre_sql_setup (with_col_aliases = False )
516
- query .sub_query = compiler .build_query ()
574
+ columns = (
575
+ compiler .get_columns ()
576
+ if compiler .query .annotations or not compiler .query .default_cols
577
+ else None
578
+ )
579
+ subquery = compiler .build_query (
580
+ # Avoid $project (columns=None) if unneeded.
581
+ columns
582
+ )
583
+ query .subquery = subquery
517
584
return query
518
585
519
586
def _make_result (self , result , columns = None ):
0 commit comments