Skip to content

Commit 7a72922

Browse files
committed
Simplify and rename EmbeddedModelField's transform classes
1 parent 6d5537c commit 7a72922

File tree

2 files changed

+19
-19
lines changed

2 files changed

+19
-19
lines changed

django_mongodb_backend/fields/embedded_model.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ def get_transform(self, name):
132132
if transform:
133133
return transform
134134
field = self.embedded_model._meta.get_field(name)
135-
return KeyTransformFactory(name, field)
135+
return EmbeddedModelTransformFactory(field)
136136

137137
def validate(self, value, model_instance):
138138
super().validate(value, model_instance)
@@ -156,39 +156,40 @@ def formfield(self, **kwargs):
156156
)
157157

158158

159-
class KeyTransform(Transform):
160-
def __init__(self, key_name, ref_field, *args, **kwargs):
159+
class EmbeddedModelTransform(Transform):
160+
def __init__(self, field, *args, **kwargs):
161161
super().__init__(*args, **kwargs)
162-
self.key_name = str(key_name)
163-
self.ref_field = ref_field
162+
# self.field aliases self._field via BaseExpression.field returning
163+
# self.output_field.
164+
self._field = field
164165

165166
def get_lookup(self, name):
166-
return self.ref_field.get_lookup(name)
167+
return self.field.get_lookup(name)
167168

168169
def get_transform(self, name):
169170
"""
170171
Validate that `name` is either a field of an embedded model or a
171172
lookup on an embedded model's field.
172173
"""
173-
if transform := self.ref_field.get_transform(name):
174+
if transform := self.field.get_transform(name):
174175
return transform
175-
suggested_lookups = difflib.get_close_matches(name, self.ref_field.get_lookups())
176+
suggested_lookups = difflib.get_close_matches(name, self.field.get_lookups())
176177
if suggested_lookups:
177178
suggested_lookups = " or ".join(suggested_lookups)
178179
suggestion = f", perhaps you meant {suggested_lookups}?"
179180
else:
180181
suggestion = "."
181182
raise FieldDoesNotExist(
182183
f"Unsupported lookup '{name}' for "
183-
f"{self.ref_field.__class__.__name__} '{self.ref_field.name}'"
184+
f"{self.field.__class__.__name__} '{self.field.name}'"
184185
f"{suggestion}"
185186
)
186187

187188
def as_mql(self, compiler, connection, as_path=False):
188189
previous = self
189190
columns = []
190-
while isinstance(previous, KeyTransform):
191-
columns.insert(0, previous.ref_field.column)
191+
while isinstance(previous, EmbeddedModelTransform):
192+
columns.insert(0, previous.field.column)
192193
previous = previous.lhs
193194
if as_path:
194195
mql = previous.as_mql(compiler, connection, as_path=True)
@@ -201,13 +202,12 @@ def as_mql(self, compiler, connection, as_path=False):
201202

202203
@property
203204
def output_field(self):
204-
return self.ref_field
205+
return self._field
205206

206207

207-
class KeyTransformFactory:
208-
def __init__(self, key_name, ref_field):
209-
self.key_name = key_name
210-
self.ref_field = ref_field
208+
class EmbeddedModelTransformFactory:
209+
def __init__(self, field):
210+
self.field = field
211211

212212
def __call__(self, *args, **kwargs):
213-
return KeyTransform(self.key_name, self.ref_field, *args, **kwargs)
213+
return EmbeddedModelTransform(self.field, *args, **kwargs)

django_mongodb_backend/fields/polymorphic_embedded_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from django.db import models
66
from django.db.models.fields.related import lazy_related_operation
77

8-
from .embedded_model import KeyTransformFactory
8+
from .embedded_model import EmbeddedModelTransformFactory
99
from .utils import get_mongodb_connection
1010

1111

@@ -170,7 +170,7 @@ def get_transform(self, name):
170170
raise FieldDoesNotExist(
171171
f"The models of field '{self.name}' have no field named '{name}'."
172172
)
173-
return KeyTransformFactory(name, field)
173+
return EmbeddedModelTransformFactory(field)
174174

175175
def validate(self, value, model_instance):
176176
super().validate(value, model_instance)

0 commit comments

Comments
 (0)