@@ -132,7 +132,7 @@ def get_transform(self, name):
132
132
if transform :
133
133
return transform
134
134
field = self .embedded_model ._meta .get_field (name )
135
- return KeyTransformFactory ( name , field )
135
+ return EmbeddedModelTransformFactory ( field )
136
136
137
137
def validate (self , value , model_instance ):
138
138
super ().validate (value , model_instance )
@@ -156,39 +156,40 @@ def formfield(self, **kwargs):
156
156
)
157
157
158
158
159
- class KeyTransform (Transform ):
160
- def __init__ (self , key_name , ref_field , * args , ** kwargs ):
159
+ class EmbeddedModelTransform (Transform ):
160
+ def __init__ (self , field , * args , ** kwargs ):
161
161
super ().__init__ (* args , ** kwargs )
162
- self .key_name = str (key_name )
163
- self .ref_field = ref_field
162
+ # self.field aliases self._field via BaseExpression.field returning
163
+ # self.output_field.
164
+ self ._field = field
164
165
165
166
def get_lookup (self , name ):
166
- return self .ref_field .get_lookup (name )
167
+ return self .field .get_lookup (name )
167
168
168
169
def get_transform (self , name ):
169
170
"""
170
171
Validate that `name` is either a field of an embedded model or a
171
172
lookup on an embedded model's field.
172
173
"""
173
- if transform := self .ref_field .get_transform (name ):
174
+ if transform := self .field .get_transform (name ):
174
175
return transform
175
- suggested_lookups = difflib .get_close_matches (name , self .ref_field .get_lookups ())
176
+ suggested_lookups = difflib .get_close_matches (name , self .field .get_lookups ())
176
177
if suggested_lookups :
177
178
suggested_lookups = " or " .join (suggested_lookups )
178
179
suggestion = f", perhaps you meant { suggested_lookups } ?"
179
180
else :
180
181
suggestion = "."
181
182
raise FieldDoesNotExist (
182
183
f"Unsupported lookup '{ name } ' for "
183
- f"{ self .ref_field .__class__ .__name__ } '{ self .ref_field .name } '"
184
+ f"{ self .field .__class__ .__name__ } '{ self .field .name } '"
184
185
f"{ suggestion } "
185
186
)
186
187
187
188
def as_mql (self , compiler , connection , as_path = False ):
188
189
previous = self
189
190
columns = []
190
- while isinstance (previous , KeyTransform ):
191
- columns .insert (0 , previous .ref_field .column )
191
+ while isinstance (previous , EmbeddedModelTransform ):
192
+ columns .insert (0 , previous .field .column )
192
193
previous = previous .lhs
193
194
if as_path :
194
195
mql = previous .as_mql (compiler , connection , as_path = True )
@@ -201,13 +202,12 @@ def as_mql(self, compiler, connection, as_path=False):
201
202
202
203
@property
203
204
def output_field (self ):
204
- return self .ref_field
205
+ return self ._field
205
206
206
207
207
- class KeyTransformFactory :
208
- def __init__ (self , key_name , ref_field ):
209
- self .key_name = key_name
210
- self .ref_field = ref_field
208
+ class EmbeddedModelTransformFactory :
209
+ def __init__ (self , field ):
210
+ self .field = field
211
211
212
212
def __call__ (self , * args , ** kwargs ):
213
- return KeyTransform (self .key_name , self . ref_field , * args , ** kwargs )
213
+ return EmbeddedModelTransform (self .field , * args , ** kwargs )
0 commit comments