1
+ import difflib
2
+
1
3
from django .core import checks
4
+ from django .core .exceptions import FieldDoesNotExist
2
5
from django .db import models
3
6
from django .db .models .fields .related import lazy_related_operation
4
7
from django .db .models .lookups import Transform
@@ -123,7 +126,8 @@ def get_transform(self, name):
123
126
transform = super ().get_transform (name )
124
127
if transform :
125
128
return transform
126
- return KeyTransformFactory (name )
129
+ field = self .embedded_model ._meta .get_field (name )
130
+ return KeyTransformFactory (name , field )
127
131
128
132
def validate (self , value , model_instance ):
129
133
super ().validate (value , model_instance )
@@ -145,9 +149,36 @@ def formfield(self, **kwargs):
145
149
146
150
147
151
class KeyTransform (Transform ):
148
- def __init__ (self , key_name , * args , ** kwargs ):
152
+ def __init__ (self , key_name , ref_field , * args , ** kwargs ):
149
153
super ().__init__ (* args , ** kwargs )
150
154
self .key_name = str (key_name )
155
+ self .ref_field = ref_field
156
+
157
+ def get_transform (self , name ):
158
+ """
159
+ Validate that `name` is either a field of an embedded model or a
160
+ lookup on an embedded model's field.
161
+ """
162
+ result = None
163
+ if isinstance (self .ref_field , EmbeddedModelField ):
164
+ opts = self .ref_field .embedded_model ._meta
165
+ new_field = opts .get_field (name )
166
+ result = KeyTransformFactory (name , new_field )
167
+ else :
168
+ if self .ref_field .get_transform (name ) is None :
169
+ suggested_lookups = difflib .get_close_matches (name , self .ref_field .get_lookups ())
170
+ if suggested_lookups :
171
+ suggested_lookups = " or " .join (suggested_lookups )
172
+ suggestion = f", perhaps you meant { suggested_lookups } ?"
173
+ else :
174
+ suggestion = "."
175
+ raise FieldDoesNotExist (
176
+ f"Unsupported lookup '{ name } ' for "
177
+ f"{ self .ref_field .__class__ .__name__ } '{ self .ref_field .name } '"
178
+ f"{ suggestion } "
179
+ )
180
+ result = KeyTransformFactory (name , self .ref_field )
181
+ return result
151
182
152
183
def preprocess_lhs (self , compiler , connection ):
153
184
key_transforms = [self .key_name ]
@@ -165,8 +196,9 @@ def as_mql(self, compiler, connection):
165
196
166
197
167
198
class KeyTransformFactory :
168
- def __init__ (self , key_name ):
199
+ def __init__ (self , key_name , ref_field ):
169
200
self .key_name = key_name
201
+ self .ref_field = ref_field
170
202
171
203
def __call__ (self , * args , ** kwargs ):
172
- return KeyTransform (self .key_name , * args , ** kwargs )
204
+ return KeyTransform (self .key_name , self . ref_field , * args , ** kwargs )
0 commit comments