@@ -19,10 +19,18 @@ class SQLCompiler(compiler.SQLCompiler):
19
19
query_class = MongoQuery
20
20
_group_pipeline = None
21
21
22
+ @staticmethod
23
+ def _random_separtor ():
24
+ import random
25
+ import string
26
+
27
+ size = 6
28
+ chars = string .ascii_uppercase + string .digits
29
+ return "" .join (random .choice (chars ) for _ in range (size )) # noqa: S311
30
+
22
31
def pre_sql_setup (self , with_col_aliases = False ):
23
32
pre_setup = super ().pre_sql_setup (with_col_aliases = with_col_aliases )
24
33
self .annotations = {}
25
- # mongo_having = self.having.copy() if self.having else None
26
34
group = {}
27
35
group_expressions = set ()
28
36
aggregation_idx = 1
@@ -38,13 +46,14 @@ def pre_sql_setup(self, with_col_aliases=False):
38
46
aggregation_idx += 1
39
47
else :
40
48
alias = target
41
- group_expressions |= set (sub_expr .get_group_by_cols ())
42
49
group [alias ] = sub_expr .as_mql (self , self .connection )
43
50
column_target = expr .output_field .__class__ ()
51
+ column_target .db_column = alias
44
52
column_target .set_attributes_from_name (alias )
45
53
replacements [sub_expr ] = Col (self .collection_name , column_target )
46
54
result_expr = expr .replace_expressions (replacements )
47
55
all_replacements .update (replacements )
56
+ group_expressions |= set (expr .get_group_by_cols ())
48
57
self .annotations [target ] = result_expr
49
58
if group :
50
59
order_by = self .get_order_by ()
@@ -59,24 +68,69 @@ def pre_sql_setup(self, with_col_aliases=False):
59
68
if isinstance (self .query .group_by , tuple | list ):
60
69
group_expressions |= set (self .query .group_by )
61
70
71
+ all_strings = "" .join (
72
+ str (col .as_mql (self , self .connection )) for col in group_expressions
73
+ )
74
+
75
+ while True :
76
+ random_string = self ._random_separtor ()
77
+ if random_string not in all_strings :
78
+ break
79
+ SEPARATOR = f"__{ random_string } __"
80
+
81
+ def _ccc (col ):
82
+ if self .collection_name == col .alias :
83
+ return col .target .column
84
+ return f"{ col .alias } { SEPARATOR } { col .target .column } "
85
+
62
86
ids = (
63
87
None
64
88
if not group_expressions
65
89
else {
66
- col . target . column : col .as_mql (self , self .connection )
90
+ _ccc ( col ) : col .as_mql (self , self .connection )
67
91
# expression aren't needed in the group by clouse ()
68
92
for col in group_expressions
69
93
if isinstance (col , Col )
70
94
}
71
95
)
72
- group ["_id" ] = ids
73
- pipeline = [{"$group" : group }]
74
- if ids :
96
+ pipeline = []
97
+ if ids is None :
98
+ group ["_id" ] = None
99
+ pipeline .append ({"$facet" : {"group" : [{"$group" : group }]}})
75
100
pipeline .append (
76
- {"$addFields" : {key : f"$_id.{ value [1 :]} " for key , value in ids .items ()}}
101
+ {
102
+ "$project" : {
103
+ key : {
104
+ "$getField" : {
105
+ "input" : {"$arrayElemAt" : ["$group" , 0 ]},
106
+ "field" : key ,
107
+ }
108
+ }
109
+ for key in group
110
+ }
111
+ }
77
112
)
78
- if "_id" not in ids :
113
+ else :
114
+ group ["_id" ] = ids
115
+ pipeline .append ({"$group" : group })
116
+ sets = {}
117
+ for key in ids :
118
+ value = f"$_id.{ key } "
119
+ if SEPARATOR in key :
120
+ subtable , field = key .split (SEPARATOR )
121
+ if subtable not in sets :
122
+ sets [subtable ] = {}
123
+ sets [subtable ][field ] = value
124
+ else :
125
+ sets [key ] = value
126
+
127
+ pipeline .append (
128
+ # {"$addFields": {key: f"$_id.{value[1:]}" for key, value in ids.items()}}
129
+ {"$addFields" : sets }
130
+ )
131
+ if "_id" not in sets :
79
132
pipeline .append ({"$unset" : "_id" })
133
+
80
134
if self .having :
81
135
pipeline .append (
82
136
{
@@ -252,14 +306,14 @@ def build_query(self, columns=None):
252
306
query = self .query_class (self )
253
307
query .aggregation_stage = self .get_aggregation_pipeline ()
254
308
query .lookup_pipeline = self .get_lookup_pipeline ()
255
- query .project_fields = self .get_project_fields (columns )
309
+ query .order_by (self ._get_ordering ())
310
+ query .project_fields = self .get_project_fields (columns , ordering = query .ordering )
256
311
try :
257
312
query .mongo_query = (
258
313
{"$expr" : self .where .as_mql (self , self .connection )} if self .where else None
259
314
)
260
315
except FullResultSet :
261
316
query .mongo_query = {}
262
- query .order_by (self ._get_ordering ())
263
317
return query
264
318
265
319
def get_columns (self ):
@@ -371,7 +425,7 @@ def _get_aggregate_expressions(self, expr):
371
425
def get_aggregation_pipeline (self ):
372
426
return self ._group_pipeline
373
427
374
- def get_project_fields (self , columns = None ):
428
+ def get_project_fields (self , columns = None , ordering = None ):
375
429
fields = {}
376
430
for name , expr in columns or []:
377
431
try :
@@ -395,6 +449,10 @@ def get_project_fields(self, columns=None):
395
449
if self .query .alias_refcount [alias ] and self .collection_name != alias :
396
450
fields [alias ] = 1
397
451
452
+ for column , _ in ordering or []:
453
+ if column not in fields :
454
+ fields [column ] = 1
455
+
398
456
return fields
399
457
400
458
@@ -515,7 +573,16 @@ def build_query(self, columns=None):
515
573
elide_empty = self .elide_empty ,
516
574
)
517
575
compiler .pre_sql_setup (with_col_aliases = False )
518
- query .sub_query = compiler .build_query ()
576
+ columns = (
577
+ compiler .get_columns ()
578
+ if compiler .query .annotations or not compiler .query .default_cols
579
+ else None
580
+ )
581
+ subquery = compiler .build_query (
582
+ # Avoid $project (columns=None) if unneeded.
583
+ columns
584
+ )
585
+ query .subquery = subquery
519
586
return query
520
587
521
588
def _make_result (self , result , columns = None ):
0 commit comments