Skip to content

Commit 1bad8e5

Browse files
support filter agrs in default reference resolvers
1 parent 75ba65e commit 1bad8e5

File tree

1 file changed

+39
-11
lines changed

1 file changed

+39
-11
lines changed

graphene_mongo/converter.py

Lines changed: 39 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -112,9 +112,14 @@ def get_reference_objects(*args, **kwargs):
112112
document_field = convert_mongoengine_field(document_field, registry)
113113
document_field_type = document_field.get_type().type
114114
queried_fields = list()
115+
filter_args = list()
116+
if document_field_type._meta.filter_fields:
117+
for key, values in document_field_type._meta.filter_fields.items():
118+
for each in values:
119+
filter_args.append(key + "__" + each)
115120
for each in get_query_fields(args[0][3][0])[document_field_type._meta.name].keys():
116121
item = to_snake_case(each)
117-
if item in document._fields_ordered:
122+
if item in document._fields_ordered + tuple(filter_args):
118123
queried_fields.append(item)
119124
return document.objects().no_dereference().only(
120125
*set(list(document_field_type._meta.required_fields) + queried_fields)).filter(pk__in=args[0][1])
@@ -229,12 +234,17 @@ def reference_resolver(root, *args, **kwargs):
229234
document_field = mongoengine.ReferenceField(document)
230235
document_field = convert_mongoengine_field(document_field, registry)
231236
_type = document_field.get_type().type
237+
filter_args = list()
238+
if _type._meta.filter_fields:
239+
for key, values in _type._meta.filter_fields.items():
240+
for each in values:
241+
filter_args.append(key + "__" + each)
232242
querying_types = list(get_query_fields(args[0]).keys())
233243
if _type.__name__ in querying_types:
234244
queried_fields = list()
235245
for each in get_query_fields(args[0]).keys():
236246
item = to_snake_case(each)
237-
if item in document._fields_ordered:
247+
if item in document._fields_ordered + tuple(filter_args):
238248
queried_fields.append(item)
239249
return document.objects().no_dereference().only(*list(
240250
set(list(_type._meta.required_fields) + queried_fields))).get(
@@ -247,7 +257,8 @@ def reference_resolver(root, *args, **kwargs):
247257
required = False
248258
if field.db_field is not None:
249259
required = field.required
250-
resolver_function = getattr(registry.get_type_for_model(field.owner_document), "resolve_" + field.db_field, None)
260+
resolver_function = getattr(registry.get_type_for_model(field.owner_document), "resolve_" + field.db_field,
261+
None)
251262
if resolver_function and callable(resolver_function):
252263
field_resolver = resolver_function
253264
return graphene.Field(_union, resolver=field_resolver if field_resolver else reference_resolver,
@@ -266,11 +277,16 @@ def reference_resolver(root, *args, **kwargs):
266277
document = getattr(root, field.name or field.db_name)
267278
if document:
268279
queried_fields = list()
280+
_type = registry.get_type_for_model(field.document_type)
281+
filter_args = list()
282+
if _type._meta.filter_fields:
283+
for key, values in _type._meta.filter_fields.items():
284+
for each in values:
285+
filter_args.append(key + "__" + each)
269286
for each in get_query_fields(args[0]).keys():
270287
item = to_snake_case(each)
271-
if item in field.document_type._fields_ordered:
288+
if item in field.document_type._fields_ordered + tuple(filter_args):
272289
queried_fields.append(item)
273-
_type = registry.get_type_for_model(field.document_type)
274290
return field.document_type.objects().no_dereference().only(
275291
*(set(list(_type._meta.required_fields) + queried_fields))).get(
276292
pk=document.id)
@@ -279,11 +295,16 @@ def reference_resolver(root, *args, **kwargs):
279295
def cached_reference_resolver(root, *args, **kwargs):
280296
if field:
281297
queried_fields = list()
298+
_type = registry.get_type_for_model(field.document_type)
299+
filter_args = list()
300+
if _type._meta.filter_fields:
301+
for key, values in _type._meta.filter_fields.items():
302+
for each in values:
303+
filter_args.append(key + "__" + each)
282304
for each in get_query_fields(args[0]).keys():
283305
item = to_snake_case(each)
284-
if item in field.document_type._fields_ordered:
306+
if item in field.document_type._fields_ordered + tuple(filter_args):
285307
queried_fields.append(item)
286-
_type = registry.get_type_for_model(field.document_type)
287308
return field.document_type.objects().no_dereference().only(
288309
*(set(
289310
list(_type._meta.required_fields) + queried_fields))).get(
@@ -301,7 +322,8 @@ def dynamic_type():
301322
required = False
302323
if field.db_field is not None:
303324
required = field.required
304-
resolver_function = getattr(registry.get_type_for_model(field.owner_document), "resolve_" + field.db_field, None)
325+
resolver_function = getattr(registry.get_type_for_model(field.owner_document), "resolve_" + field.db_field,
326+
None)
305327
if resolver_function and callable(resolver_function):
306328
field_resolver = resolver_function
307329
if isinstance(field, mongoengine.ReferenceField):
@@ -322,11 +344,16 @@ def lazy_resolver(root, *args, **kwargs):
322344
document = getattr(root, field.name or field.db_name)
323345
if document:
324346
queried_fields = list()
347+
_type = registry.get_type_for_model(document.document_type)
348+
filter_args = list()
349+
if _type._meta.filter_fields:
350+
for key, values in _type._meta.filter_fields.items():
351+
for each in values:
352+
filter_args.append(key + "__" + each)
325353
for each in get_query_fields(args[0]).keys():
326354
item = to_snake_case(each)
327-
if item in document.document_type._fields_ordered:
355+
if item in document.document_type._fields_ordered + tuple(filter_args):
328356
queried_fields.append(item)
329-
_type = registry.get_type_for_model(document.document_type)
330357
return document.document_type.objects().no_dereference().only(
331358
*(set((list(_type._meta.required_fields) + queried_fields)))).get(
332359
pk=document.pk)
@@ -340,7 +367,8 @@ def dynamic_type():
340367
required = False
341368
if field.db_field is not None:
342369
required = field.required
343-
resolver_function = getattr(registry.get_type_for_model(field.owner_document), "resolve_" + field.db_field, None)
370+
resolver_function = getattr(registry.get_type_for_model(field.owner_document), "resolve_" + field.db_field,
371+
None)
344372
if resolver_function and callable(resolver_function):
345373
field_resolver = resolver_function
346374
return graphene.Field(

0 commit comments

Comments
 (0)