|
1 | 1 | # Copyright (c) Microsoft Corporation. |
2 | 2 | # Licensed under the MIT license. |
3 | 3 |
|
| 4 | +import json |
| 5 | + |
4 | 6 | from django import VERSION |
5 | 7 | from django.db.models import BooleanField |
6 | 8 | from django.db.models.functions import Cast |
7 | 9 | from django.db.models.functions.math import ATan2, Log, Ln, Mod, Round |
8 | 10 | from django.db.models.expressions import Case, Exists, OrderBy, When |
9 | | -from django.db.models.lookups import Lookup, In, Exact |
| 11 | +from django.db.models.lookups import Lookup, In |
| 12 | +from django.db.models import lookups |
| 13 | + |
10 | 14 | if VERSION >= (3, 1): |
11 | | - from django.db.models.fields.json import KeyTransform, KeyTransformExact |
| 15 | + from django.db.models.fields.json import KeyTransform, KeyTransformIn, KeyTransformExact |
12 | 16 |
|
13 | 17 | DJANGO3 = VERSION[0] >= 3 |
14 | 18 |
|
@@ -110,19 +114,31 @@ def split_parameter_list_as_sql(self, compiler, connection): |
110 | 114 |
|
111 | 115 | return in_clause, () |
112 | 116 |
|
113 | | -def KeyTransformExact_process_rhs(self, compiler, connection): |
| 117 | +def unquote_json_rhs(rhs_params): |
| 118 | + for value in rhs_params: |
| 119 | + value = json.loads(value) |
| 120 | + if not isinstance(value, (list, dict)): |
| 121 | + rhs_params = [param.replace('"', '') for param in rhs_params] |
| 122 | + return rhs_params |
| 123 | + |
| 124 | +def json_KeyTransformExact_process_rhs(self, compiler, connection): |
114 | 125 | if isinstance(self.rhs, KeyTransform): |
115 | | - return super(Exact, self).process_rhs(compiler, connection) |
116 | | - rhs, rhs_params = super(Exact, self).process_rhs(compiler, connection) |
117 | | - if connection.vendor == 'microsoft': |
118 | | - if rhs_params != [None]: |
119 | | - rhs_params = [params.strip('"') for params in rhs_params] |
120 | | - return rhs, rhs_params |
| 126 | + return super(lookups.Exact, self).process_rhs(compiler, connection) |
| 127 | + rhs, rhs_params = super(KeyTransformExact, self).process_rhs(compiler, connection) |
| 128 | + |
| 129 | + return rhs, unquote_json_rhs(rhs_params) |
| 130 | + |
| 131 | +def json_KeyTransformIn(self, compiler, connection): |
| 132 | + lhs, _ = super(KeyTransformIn, self).process_lhs(compiler, connection) |
| 133 | + rhs, rhs_params = super(KeyTransformIn, self).process_rhs(compiler, connection) |
| 134 | + |
| 135 | + return (lhs + ' IN ' + rhs, unquote_json_rhs(rhs_params)) |
121 | 136 |
|
122 | 137 | ATan2.as_microsoft = sqlserver_atan2 |
123 | 138 | In.split_parameter_list_as_sql = split_parameter_list_as_sql |
124 | 139 | if VERSION >= (3, 1): |
125 | | - KeyTransformExact.process_rhs = KeyTransformExact_process_rhs |
| 140 | + KeyTransformIn.as_microsoft = json_KeyTransformIn |
| 141 | + KeyTransformExact.process_rhs = json_KeyTransformExact_process_rhs |
126 | 142 | Ln.as_microsoft = sqlserver_ln |
127 | 143 | Log.as_microsoft = sqlserver_log |
128 | 144 | Mod.as_microsoft = sqlserver_mod |
|
0 commit comments