@@ -66,55 +66,56 @@ class EMFArrayExact(EMFExact):
66
66
def as_mql (self , compiler , connection ):
67
67
lhs_mql = process_lhs (self , compiler , connection )
68
68
value = process_rhs (self , compiler , connection )
69
- if isinstance (self .lhs , Col | KeyTransform ):
70
- if isinstance (self .lhs , Col ):
71
- inner_lhs_mql = "$$item"
72
- else :
73
- lhs_mql , inner_lhs_mql = lhs_mql
74
- if isinstance (value , models .Model ):
75
- value , emf_data = self .model_to_dict (value )
76
- # Get conditions for any nested EmbeddedModelFields.
77
- conditions = self .get_conditions ({inner_lhs_mql : (value , emf_data )})
78
- return {
79
- "$anyElementTrue" : {
80
- "$ifNull" : [
81
- {
82
- "$map" : {
83
- "input" : lhs_mql ,
84
- "as" : "item" ,
85
- "in" : {"$and" : conditions },
86
- }
87
- },
88
- [],
89
- ]
90
- }
91
- }
69
+ if isinstance (self .lhs , KeyTransform ):
70
+ lhs_mql , inner_lhs_mql = lhs_mql
71
+ else :
72
+ inner_lhs_mql = "$$item"
73
+ if isinstance (value , models .Model ):
74
+ value , emf_data = self .model_to_dict (value )
75
+ # Get conditions for any nested EmbeddedModelFields.
76
+ conditions = self .get_conditions ({inner_lhs_mql : (value , emf_data )})
92
77
return {
93
78
"$anyElementTrue" : {
94
79
"$ifNull" : [
95
80
{
96
81
"$map" : {
97
82
"input" : lhs_mql ,
98
83
"as" : "item" ,
99
- "in" : {"$eq " : [ inner_lhs_mql , value ] },
84
+ "in" : {"$and " : conditions },
100
85
}
101
86
},
102
87
[],
103
88
]
104
89
}
105
90
}
106
- return connection .mongo_operators [self .lookup_name ](lhs_mql , value )
91
+ return {
92
+ "$anyElementTrue" : {
93
+ "$ifNull" : [
94
+ {
95
+ "$map" : {
96
+ "input" : lhs_mql ,
97
+ "as" : "item" ,
98
+ "in" : {"$eq" : [inner_lhs_mql , value ]},
99
+ }
100
+ },
101
+ [],
102
+ ]
103
+ }
104
+ }
107
105
108
106
109
107
class KeyTransform (Transform ):
110
108
# it should be different class than EMF keytransform even most of the methods are equal.
111
109
def __init__ (self , key_name , base_field , * args , ** kwargs ):
112
110
super ().__init__ (* args , ** kwargs )
113
111
self .base_field = base_field
114
- # TODO: Need to create a column, will refactor this thing.
112
+ self .key_name = key_name
113
+ # The iteration items begins from the base_field, a virtual column with
114
+ # base field output type is created.
115
115
column_target = base_field .clone ()
116
- column_target .db_column = f"$item.{ key_name } "
117
- column_target .set_attributes_from_name (f"$item.{ key_name } " )
116
+ column_name = f"$item.{ key_name } "
117
+ column_target .db_column = column_name
118
+ column_target .set_attributes_from_name (column_name )
118
119
self ._lhs = Col (None , column_target )
119
120
self ._sub_transform = None
120
121
@@ -125,19 +126,8 @@ def __call__(self, this, *args, **kwargs):
125
126
def get_lookup (self , name ):
126
127
return self .output_field .get_lookup (name )
127
128
128
- def get_transform (self , name ):
129
- """
130
- Validate that `name` is either a field of an embedded model or a
131
- lookup on an embedded model's field.
132
- """
133
- if isinstance (self ._lhs , Transform ):
134
- transform = self ._lhs .get_transform (name )
135
- else :
136
- transform = self .base_field .get_transform (name )
137
- if transform :
138
- self ._sub_transform = transform
139
- return self
140
- suggested_lookups = difflib .get_close_matches (name , self .base_field .get_lookups ())
129
+ def _get_missing_field_or_lookup_exception (self , lhs , name ):
130
+ suggested_lookups = difflib .get_close_matches (name , lhs .get_lookups ())
141
131
if suggested_lookups :
142
132
suggested_lookups = " or " .join (suggested_lookups )
143
133
suggestion = f", perhaps you meant { suggested_lookups } ?"
@@ -149,6 +139,25 @@ def get_transform(self, name):
149
139
f"{ suggestion } "
150
140
)
151
141
142
+ def get_transform (self , name ):
143
+ """
144
+ Validate that `name` is either a field of an embedded model or a
145
+ lookup on an embedded model's field.
146
+ """
147
+ # Once the sub lhs is a transform, all the filter are applied over it.
148
+
149
+ transform = (
150
+ self ._lhs .get_transform (name )
151
+ if isinstance (self ._lhs , Transform )
152
+ else self .base_field .get_transform (name )
153
+ )
154
+ if transform :
155
+ self ._sub_transform = transform
156
+ return self
157
+ raise self ._get_missing_field_or_lookup_exception (
158
+ self ._lhs if isinstance (self ._lhs , Transform ) else self .base_field , name
159
+ )
160
+
152
161
def as_mql (self , compiler , connection ):
153
162
inner_lhs_mql = self ._lhs .as_mql (compiler , connection )
154
163
lhs_mql = process_lhs (self , compiler , connection )
0 commit comments