Skip to content

Commit 8f6a4c0

Browse files
Implemented cross checking for defined resolver before assigning default resolver for Reference, Generic Reference & Lazy Reference fields
1 parent 5d54b01 commit 8f6a4c0

File tree

1 file changed

+28
-13
lines changed

1 file changed

+28
-13
lines changed

graphene_mongo/converter.py

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ def get_reference_objects(*args, **kwargs):
113113
document_field = convert_mongoengine_field(document_field, registry)
114114
document_field_type = document_field.get_type().type._meta.name
115115
required_fields = [to_snake_case(i) for i in
116-
get_query_fields(args[0][3][0])[document_field_type].keys()]
116+
get_query_fields(args[0][3][0])[document_field_type].keys()]
117117
return document.objects().no_dereference().only(*required_fields).filter(pk__in=args[0][1])
118118
else:
119119
return []
@@ -221,7 +221,12 @@ def reference_resolver(root, *args, **kwargs):
221221
return None
222222

223223
if isinstance(field, mongoengine.GenericReferenceField):
224-
return graphene.Field(_union, resolver=reference_resolver,
224+
field_resolver = None
225+
if field.db_field is not None:
226+
resolver_function = getattr(_union, "resolve_" + field.db_field, None)
227+
if resolver_function and callable(resolver_function):
228+
field_resolver = resolver_function
229+
return graphene.Field(_union, resolver=field_resolver if field_resolver else reference_resolver,
225230
description=get_field_description(field, registry))
226231

227232
return graphene.Field(_union)
@@ -238,7 +243,7 @@ def reference_resolver(root, *args, **kwargs):
238243
if document:
239244
_type = registry.get_type_for_model(field.document_type)
240245
required_fields = _type._meta.required_fields.split(",") if isinstance(_type._meta.required_fields,
241-
str) else list()
246+
str) else list()
242247
return field.document_type.objects().no_dereference().only(
243248
*((list(set(required_fields + [to_snake_case(i) for i in get_query_fields(args[0]).keys()]))))).get(
244249
pk=document.id)
@@ -248,7 +253,7 @@ def cached_reference_resolver(root, *args, **kwargs):
248253
if field:
249254
_type = registry.get_type_for_model(field.document_type)
250255
required_fields = _type._meta.required_fields.split(",") if isinstance(_type._meta.required_fields,
251-
str) else list()
256+
str) else list()
252257
return field.document_type.objects().no_dereference().only(
253258
*(list(set(required_fields + [to_snake_case(i) for i in get_query_fields(args[0]).keys()]))
254259
)).get(
@@ -259,14 +264,20 @@ def dynamic_type():
259264
_type = registry.get_type_for_model(model)
260265
if not _type:
261266
return None
262-
elif isinstance(field, mongoengine.ReferenceField):
263-
return graphene.Field(_type, resolver=reference_resolver,
267+
if isinstance(field, mongoengine.EmbeddedDocumentField):
268+
return graphene.Field(_type,
264269
description=get_field_description(field, registry))
265-
elif isinstance(field, mongoengine.CachedReferenceField):
266-
return graphene.Field(_type, resolver=cached_reference_resolver,
270+
field_resolver = None
271+
if field.db_field is not None:
272+
resolver_function = getattr(_type, "resolve_" + field.db_field, None)
273+
if resolver_function and callable(resolver_function):
274+
field_resolver = resolver_function
275+
if isinstance(field, mongoengine.ReferenceField):
276+
return graphene.Field(_type, resolver=field_resolver if field_resolver else reference_resolver,
277+
description=get_field_description(field, registry))
278+
else:
279+
return graphene.Field(_type, resolver=field_resolver if field_resolver else cached_reference_resolver(),
267280
description=get_field_description(field, registry))
268-
return graphene.Field(_type,
269-
description=get_field_description(field, registry))
270281

271282
return graphene.Dynamic(dynamic_type)
272283

@@ -280,20 +291,24 @@ def lazy_resolver(root, *args, **kwargs):
280291
if document:
281292
_type = registry.get_type_for_model(document.document_type)
282293
required_fields = _type._meta.required_fields.split(",") if isinstance(_type._meta.required_fields,
283-
str) else list()
294+
str) else list()
284295
return document.document_type.objects().no_dereference().only(
285296
*(list(set((required_fields + [to_snake_case(i) for i in get_query_fields(args[0]).keys()]))))).get(
286297
pk=document.pk)
287298
return None
288299

289300
def dynamic_type():
290301
_type = registry.get_type_for_model(model)
291-
292302
if not _type:
293303
return None
304+
field_resolver = None
305+
if field.db_field is not None:
306+
resolver_function = getattr(_type, "resolve_" + field.db_field, None)
307+
if resolver_function and callable(resolver_function):
308+
field_resolver = resolver_function
294309
return graphene.Field(
295310
_type,
296-
resolver=lazy_resolver,
311+
resolver=field_resolver if field_resolver else lazy_resolver,
297312
description=get_field_description(field, registry),
298313
)
299314

0 commit comments

Comments
 (0)