|
| 1 | +import itertools |
| 2 | + |
1 | 3 | from django.core.exceptions import EmptyResultSet, FullResultSet
|
2 |
| -from django.db.models.expressions import Col, Value |
| 4 | +from django.db.models.expressions import RawSQL |
| 5 | +from django.db.models.sql.where import AND |
3 | 6 | from django.db.models.sql import compiler
|
| 7 | +from django.db.models.fields.json import KeyTransform |
| 8 | +from django.db.models.expressions import DatabaseDefault |
| 9 | + |
| 10 | + |
| 11 | +class Flag: |
| 12 | + value = False |
| 13 | + |
| 14 | + def __init__(self, value): |
| 15 | + self.value = value |
| 16 | + |
| 17 | + def __enter__(self): |
| 18 | + self.value = True |
| 19 | + return True |
| 20 | + |
| 21 | + def __exit__(self, *args): |
| 22 | + self.value = False |
| 23 | + |
| 24 | + def __bool__(self): |
| 25 | + return self.value |
| 26 | + |
4 | 27 |
|
5 | 28 | class SQLCompiler(compiler.SQLCompiler):
|
| 29 | + in_get_select = Flag(False) |
| 30 | + in_get_order_by = Flag(False) |
| 31 | +# def get_from_clause(self): |
| 32 | +# result, params = super().get_from_clause() |
| 33 | +# jsoncolumns = {} |
| 34 | +# for column in self.query.select + tuple( |
| 35 | +# [column for column in (col.lhs for col in self.query.where.children)] |
| 36 | +# ): |
| 37 | +# if isinstance(column, KeyTransform): |
| 38 | +# if column.field.name not in jsoncolumns: |
| 39 | +# jsoncolumns[column.field.name] = {} |
| 40 | +# jsoncolumns[column.field.name][ |
| 41 | +# column.key_name |
| 42 | +# ] = f"{column.field.name}__{column.key_name}" |
| 43 | +# [] |
| 44 | +# for field in jsoncolumns: |
| 45 | +# self.query.where.add(RawSQL("%s is not null" % (field,), []), AND) |
| 46 | +# cols = ", ".join( |
| 47 | +# [ |
| 48 | +# f"{jsoncolumns[field][col_name]} VARCHAR PATH '$.{col_name}'" |
| 49 | +# for col_name in jsoncolumns[field] |
| 50 | +# ] |
| 51 | +# ) |
| 52 | +# result.append( |
| 53 | +# f""" |
| 54 | +# , JSON_TABLE("{column.field.name}", '$' COLUMNS( |
| 55 | +# {cols} |
| 56 | +# )) |
| 57 | +# """ |
| 58 | +# ) |
| 59 | + |
| 60 | +# if "model_fields_nullablejsonmodel" in self.query.alias_map: |
| 61 | +# breakpoint() |
| 62 | +# return result, params |
| 63 | + |
| 64 | + def get_select(self, with_col_aliases=False): |
| 65 | + with self.in_get_select: |
| 66 | + return super().get_select(with_col_aliases) |
| 67 | + |
| 68 | + def get_order_by(self): |
| 69 | + with self.in_get_order_by: |
| 70 | + return super().get_order_by() |
6 | 71 |
|
7 | 72 | def as_sql(self, with_limits=True, with_col_aliases=False):
|
8 | 73 | with_limit_offset = (with_limits or self.query.is_sliced) and (
|
9 | 74 | self.query.high_mark is not None or self.query.low_mark > 0
|
10 | 75 | )
|
11 | 76 | if self.query.select_for_update or not with_limit_offset:
|
12 |
| - return super().as_sql(with_limits, with_col_aliases) |
| 77 | + query, params = super().as_sql(with_limits, with_col_aliases) |
| 78 | + return query, params |
13 | 79 | try:
|
14 | 80 | extra_select, order_by, group_by = self.pre_sql_setup()
|
15 | 81 |
|
@@ -79,7 +145,10 @@ def as_sql(self, with_limits=True, with_col_aliases=False):
|
79 | 145 | order_by_result = "ORDER BY %s" % first_col
|
80 | 146 |
|
81 | 147 | if offset:
|
82 |
| - out_cols.append("ROW_NUMBER() %s AS row_number" % ("OVER (%s)" % order_by_result if order_by_result else "")) |
| 148 | + out_cols.append( |
| 149 | + "ROW_NUMBER() %s AS row_number" |
| 150 | + % ("OVER (%s)" % order_by_result if order_by_result else "") |
| 151 | + ) |
83 | 152 |
|
84 | 153 | result += [", ".join(out_cols), "FROM", *from_]
|
85 | 154 | params.extend(f_params)
|
@@ -154,19 +223,32 @@ def as_sql(self, with_limits=True, with_col_aliases=False):
|
154 | 223 | ), tuple(sub_params + params)
|
155 | 224 |
|
156 | 225 | if offset:
|
157 |
| - query = "SELECT * FROM (%s) WHERE row_number between %d AND %d ORDER BY row_number" % ( |
158 |
| - query, |
159 |
| - offset, |
160 |
| - limit, |
| 226 | + query = ( |
| 227 | + "SELECT * FROM (%s) WHERE row_number between %d AND %d ORDER BY row_number" |
| 228 | + % ( |
| 229 | + query, |
| 230 | + offset, |
| 231 | + limit, |
| 232 | + ) |
161 | 233 | )
|
162 | 234 | return query, tuple(params)
|
163 |
| - except: |
164 |
| - return super().as_sql(with_limits, with_col_aliases) |
| 235 | + except Exception: |
| 236 | + query, params = super().as_sql(with_limits, with_col_aliases) |
| 237 | + return query, params |
165 | 238 |
|
166 | 239 |
|
167 | 240 | class SQLInsertCompiler(compiler.SQLInsertCompiler, SQLCompiler):
|
168 |
| - |
| 241 | + |
169 | 242 | def as_sql(self):
|
| 243 | + |
| 244 | + if self.query.fields: |
| 245 | + fields = self.query.fields |
| 246 | + self.query.fields = [ |
| 247 | + field |
| 248 | + for field in fields |
| 249 | + if not isinstance(self.pre_save_val(field, self.query.objs[0]), DatabaseDefault) |
| 250 | + ] |
| 251 | + |
170 | 252 | if self.query.fields:
|
171 | 253 | return super().as_sql()
|
172 | 254 |
|
@@ -195,7 +277,7 @@ def as_sql(self):
|
195 | 277 | if self.connection._disable_constraint_checking:
|
196 | 278 | sql = "UPDATE %%NOCHECK" + sql[6:]
|
197 | 279 | return sql, params
|
198 |
| - |
| 280 | + |
199 | 281 |
|
200 | 282 | class SQLAggregateCompiler(compiler.SQLAggregateCompiler, SQLCompiler):
|
201 | 283 | pass
|
0 commit comments