33from django .core .exceptions import FieldDoesNotExist
44from django .db import models
55from django .db .models .expressions import Col
6- from django .db .models .lookups import Transform
6+ from django .db .models .lookups import Lookup , Transform
77
88from ..forms import EmbeddedModelArrayFormField
99from ..query_utils import process_lhs , process_rhs
1010from . import EmbeddedModelField
1111from .array import ArrayField
12- from .embedded_model import EMFExact
12+ from .embedded_model import EMFExact , EMFMixin
1313
1414
1515class EmbeddedModelArrayField (ArrayField ):
@@ -52,17 +52,8 @@ def get_transform(self, name):
5252 return KeyTransformFactory (name , self )
5353
5454
55- class ProcessRHSMixin :
56- def process_rhs (self , compiler , connection ):
57- if isinstance (self .lhs , KeyTransform ):
58- get_db_prep_value = self .lhs ._lhs .output_field .get_db_prep_value
59- else :
60- get_db_prep_value = self .lhs .output_field .get_db_prep_value
61- return None , [get_db_prep_value (v , connection , prepared = True ) for v in self .rhs ]
62-
63-
6455@EmbeddedModelArrayField .register_lookup
65- class EMFArrayExact (EMFExact , ProcessRHSMixin ):
56+ class EMFArrayExact (EMFExact ):
6657 def as_mql (self , compiler , connection ):
6758 lhs_mql = process_lhs (self , compiler , connection )
6859 value = process_rhs (self , compiler , connection )
@@ -105,12 +96,29 @@ def as_mql(self, compiler, connection):
10596
10697
10798@EmbeddedModelArrayField .register_lookup
108- class ArrayOverlap (EMFExact , ProcessRHSMixin ):
99+ class ArrayOverlap (EMFMixin , Lookup ):
109100 lookup_name = "overlap"
101+ get_db_prep_lookup_value_is_iterable = True
102+
103+ def process_rhs (self , compiler , connection ):
104+ values = self .rhs
105+ if self .get_db_prep_lookup_value_is_iterable :
106+ values = [values ]
107+ # Compute how to serialize each value based on the query target.
108+ # If querying a subfield inside the array (i.e., a nested KeyTransform), use the output
109+ # field of the subfield. Otherwise, use the base field of the array itself.
110+ if isinstance (self .lhs , KeyTransform ):
111+ get_db_prep_value = self .lhs ._lhs .output_field .get_db_prep_value
112+ else :
113+ get_db_prep_value = self .lhs .output_field .base_field .get_db_prep_value
114+ return None , [get_db_prep_value (v , connection , prepared = True ) for v in values ]
110115
111116 def as_mql (self , compiler , connection ):
112117 lhs_mql = process_lhs (self , compiler , connection )
113118 values = process_rhs (self , compiler , connection )
119+ # Querying a subfield within the array elements (via nested KeyTransform).
120+ # Replicates MongoDB's implicit ANY-match by mapping over the array and applying
121+ # `$in` on the subfield.
114122 if isinstance (self .lhs , KeyTransform ):
115123 lhs_mql , inner_lhs_mql = lhs_mql
116124 return {
@@ -129,11 +137,12 @@ def as_mql(self, compiler, connection):
129137 }
130138 conditions = []
131139 inner_lhs_mql = "$$item"
140+ # Querying full embedded documents in the array.
141+ # Builds `$or` conditions and maps them over the array to match any full document.
132142 for value in values :
133- if isinstance (value , models .Model ):
134- value , emf_data = self .model_to_dict (value )
135- # Get conditions for any nested EmbeddedModelFields.
136- conditions .append ({"$and" : self .get_conditions ({inner_lhs_mql : (value , emf_data )})})
143+ value , emf_data = self .model_to_dict (value )
144+ # Get conditions for any nested EmbeddedModelFields.
145+ conditions .append ({"$and" : self .get_conditions ({inner_lhs_mql : (value , emf_data )})})
137146 return {
138147 "$anyElementTrue" : {
139148 "$ifNull" : [
0 commit comments