24
24
FileFieldType ,
25
25
PointFieldType ,
26
26
MultiPolygonFieldType ,
27
- PolygonFieldType ,
27
+ PolygonFieldType , PointFieldInputType ,
28
28
)
29
29
from .converter import convert_mongoengine_field , MongoEngineConversionError
30
30
from .registry import get_global_registry
@@ -79,7 +79,7 @@ def registry(self):
79
79
def args (self ):
80
80
return to_arguments (
81
81
self ._base_args or OrderedDict (),
82
- dict (dict (self .field_args , ** self .reference_args ), ** self .filter_args ),
82
+ dict (dict (self .field_args , ** self .advance_args ), ** self .filter_args ),
83
83
)
84
84
85
85
@args .setter
@@ -149,35 +149,42 @@ def filter_args(self):
149
149
if self ._type ._meta .filter_fields :
150
150
for field , filter_collection in self ._type ._meta .filter_fields .items ():
151
151
for each in filter_collection :
152
- filter_type = getattr (
153
- graphene ,
154
- str (self ._type ._meta .fields [field ].type ).replace ("!" , "" ),
155
- )
156
-
152
+ if str (self ._type ._meta .fields [field ].type ) == 'PointFieldType' :
153
+ if each == 'max_distance' :
154
+ filter_type = graphene .Int
155
+ else :
156
+ filter_type = PointFieldInputType
157
+ else :
158
+ filter_type = getattr (
159
+ graphene ,
160
+ str (self ._type ._meta .fields [field ].type ).replace ("!" , "" ),
161
+ )
157
162
# handle special cases
158
163
advanced_filter_types = {
159
164
"in" : graphene .List (filter_type ),
160
165
"nin" : graphene .List (filter_type ),
161
166
"all" : graphene .List (filter_type ),
162
167
}
163
-
164
168
filter_type = advanced_filter_types .get (each , filter_type )
165
169
filter_args [field + "__" + each ] = graphene .Argument (
166
170
type = filter_type
167
171
)
168
-
169
172
return filter_args
170
173
171
174
@property
172
- def reference_args (self ):
173
- def get_reference_field (r , kv ):
175
+ def advance_args (self ):
176
+ def get_advance_field (r , kv ):
174
177
field = kv [1 ]
175
178
mongo_field = getattr (self .model , kv [0 ], None )
179
+ if isinstance (mongo_field , mongoengine .PointField ):
180
+ r .update ({kv [0 ]: graphene .Argument (PointFieldInputType )})
181
+ return r
176
182
if isinstance (
177
183
mongo_field ,
178
- (mongoengine .LazyReferenceField , mongoengine .ReferenceField ),
184
+ (mongoengine .LazyReferenceField , mongoengine .ReferenceField , mongoengine . GenericReferenceField ),
179
185
):
180
- field = convert_mongoengine_field (mongo_field , self .registry )
186
+ r .update ({kv [0 ]: graphene .ID ()})
187
+ return r
181
188
if isinstance (mongo_field , mongoengine .GenericReferenceField ):
182
189
r .update ({kv [0 ]: graphene .ID ()})
183
190
return r
@@ -192,7 +199,7 @@ def get_reference_field(r, kv):
192
199
193
200
return r
194
201
195
- return reduce (get_reference_field , self .fields .items (), {})
202
+ return reduce (get_advance_field , self .fields .items (), {})
196
203
197
204
@property
198
205
def fields (self ):
@@ -220,6 +227,12 @@ def get_queryset(self, model, info, required_fields=list(), skip=None, limit=Non
220
227
reference_obj = get_document (arg ["_cls" ])(
221
228
pk = arg ["_ref" ].id )
222
229
hydrated_references [arg_name ] = reference_obj
230
+ elif '__near' in arg_name and isinstance (getattr (self .model , arg_name .split ('__' )[0 ]),
231
+ mongoengine .fields .PointField ):
232
+ location = args .pop (arg_name , None )
233
+ hydrated_references [arg_name ] = location ["coordinates" ]
234
+ if (arg_name .split ('__' )[0 ] + "__max_distance" ) not in args :
235
+ hydrated_references [arg_name .split ('__' )[0 ] + "__max_distance" ] = 10000
223
236
elif arg_name == "id" :
224
237
hydrated_references ["id" ] = from_global_id (args .pop ("id" , None ))[1 ]
225
238
args .update (hydrated_references )
@@ -381,10 +394,17 @@ def chained_resolver(self, resolver, is_partial, root, info, **args):
381
394
self .filter_args .keys ()):
382
395
args_copy .pop (arg_name )
383
396
if arg_name == '_id' and isinstance (arg , dict ):
384
- args_copy ['pk__in' ] = arg ['$in' ]
397
+ operation = list (arg .keys ())[0 ]
398
+ args_copy ['pk' + operation .replace ('$' , '__' )] = arg [operation ]
385
399
if '.' in arg_name :
386
400
operation = list (arg .keys ())[0 ]
387
401
args_copy [arg_name .replace ('.' , '__' ) + operation .replace ('$' , '__' )] = arg [operation ]
402
+ else :
403
+ operations = ["$lte" , "$gte" , "$ne" , "$in" ]
404
+ if isinstance (arg , dict ) and any (op in arg for op in operations ):
405
+ operation = list (arg .keys ())[0 ]
406
+ args_copy [arg_name + operation .replace ('$' , '__' )] = arg [operation ]
407
+ del args_copy [arg_name ]
388
408
return self .default_resolver (root , info , required_fields , ** args_copy )
389
409
else :
390
410
return resolved
0 commit comments