@@ -69,55 +69,56 @@ class EMFArrayExact(EMFExact):
69
69
def as_mql (self , compiler , connection ):
70
70
lhs_mql = process_lhs (self , compiler , connection )
71
71
value = process_rhs (self , compiler , connection )
72
- if isinstance (self .lhs , Col | KeyTransform ):
73
- if isinstance (self .lhs , Col ):
74
- inner_lhs_mql = "$$item"
75
- else :
76
- lhs_mql , inner_lhs_mql = lhs_mql
77
- if isinstance (value , models .Model ):
78
- value , emf_data = self .model_to_dict (value )
79
- # Get conditions for any nested EmbeddedModelFields.
80
- conditions = self .get_conditions ({inner_lhs_mql : (value , emf_data )})
81
- return {
82
- "$anyElementTrue" : {
83
- "$ifNull" : [
84
- {
85
- "$map" : {
86
- "input" : lhs_mql ,
87
- "as" : "item" ,
88
- "in" : {"$and" : conditions },
89
- }
90
- },
91
- [],
92
- ]
93
- }
94
- }
72
+ if isinstance (self .lhs , KeyTransform ):
73
+ lhs_mql , inner_lhs_mql = lhs_mql
74
+ else :
75
+ inner_lhs_mql = "$$item"
76
+ if isinstance (value , models .Model ):
77
+ value , emf_data = self .model_to_dict (value )
78
+ # Get conditions for any nested EmbeddedModelFields.
79
+ conditions = self .get_conditions ({inner_lhs_mql : (value , emf_data )})
95
80
return {
96
81
"$anyElementTrue" : {
97
82
"$ifNull" : [
98
83
{
99
84
"$map" : {
100
85
"input" : lhs_mql ,
101
86
"as" : "item" ,
102
- "in" : {"$eq " : [ inner_lhs_mql , value ] },
87
+ "in" : {"$and " : conditions },
103
88
}
104
89
},
105
90
[],
106
91
]
107
92
}
108
93
}
109
- return connection .mongo_operators [self .lookup_name ](lhs_mql , value )
94
+ return {
95
+ "$anyElementTrue" : {
96
+ "$ifNull" : [
97
+ {
98
+ "$map" : {
99
+ "input" : lhs_mql ,
100
+ "as" : "item" ,
101
+ "in" : {"$eq" : [inner_lhs_mql , value ]},
102
+ }
103
+ },
104
+ [],
105
+ ]
106
+ }
107
+ }
110
108
111
109
112
110
class KeyTransform (Transform ):
113
111
# it should be different class than EMF keytransform even most of the methods are equal.
114
112
def __init__ (self , key_name , base_field , * args , ** kwargs ):
115
113
super ().__init__ (* args , ** kwargs )
116
114
self .base_field = base_field
117
- # TODO: Need to create a column, will refactor this thing.
115
+ self .key_name = key_name
116
+ # The iteration items begins from the base_field, a virtual column with
117
+ # base field output type is created.
118
118
column_target = base_field .clone ()
119
- column_target .db_column = f"$item.{ key_name } "
120
- column_target .set_attributes_from_name (f"$item.{ key_name } " )
119
+ column_name = f"$item.{ key_name } "
120
+ column_target .db_column = column_name
121
+ column_target .set_attributes_from_name (column_name )
121
122
self ._lhs = Col (None , column_target )
122
123
self ._sub_transform = None
123
124
@@ -128,19 +129,8 @@ def __call__(self, this, *args, **kwargs):
128
129
def get_lookup (self , name ):
129
130
return self .output_field .get_lookup (name )
130
131
131
- def get_transform (self , name ):
132
- """
133
- Validate that `name` is either a field of an embedded model or a
134
- lookup on an embedded model's field.
135
- """
136
- if isinstance (self ._lhs , Transform ):
137
- transform = self ._lhs .get_transform (name )
138
- else :
139
- transform = self .base_field .get_transform (name )
140
- if transform :
141
- self ._sub_transform = transform
142
- return self
143
- suggested_lookups = difflib .get_close_matches (name , self .base_field .get_lookups ())
132
+ def _get_missing_field_or_lookup_exception (self , lhs , name ):
133
+ suggested_lookups = difflib .get_close_matches (name , lhs .get_lookups ())
144
134
if suggested_lookups :
145
135
suggested_lookups = " or " .join (suggested_lookups )
146
136
suggestion = f", perhaps you meant { suggested_lookups } ?"
@@ -152,6 +142,25 @@ def get_transform(self, name):
152
142
f"{ suggestion } "
153
143
)
154
144
145
+ def get_transform (self , name ):
146
+ """
147
+ Validate that `name` is either a field of an embedded model or a
148
+ lookup on an embedded model's field.
149
+ """
150
+ # Once the sub lhs is a transform, all the filter are applied over it.
151
+
152
+ transform = (
153
+ self ._lhs .get_transform (name )
154
+ if isinstance (self ._lhs , Transform )
155
+ else self .base_field .get_transform (name )
156
+ )
157
+ if transform :
158
+ self ._sub_transform = transform
159
+ return self
160
+ raise self ._get_missing_field_or_lookup_exception (
161
+ self ._lhs if isinstance (self ._lhs , Transform ) else self .base_field , name
162
+ )
163
+
155
164
def as_mql (self , compiler , connection ):
156
165
inner_lhs_mql = self ._lhs .as_mql (compiler , connection )
157
166
lhs_mql = process_lhs (self , compiler , connection )
0 commit comments