Skip to content

Commit 66b2bbc

Browse files
Implemented cross checking for queried type in union [Generic Reference], to remove unwanted db querying for resolving reference.
1 parent 5d54b01 commit 66b2bbc

File tree

2 files changed

+71
-42
lines changed

2 files changed

+71
-42
lines changed

graphene_mongo/converter.py

Lines changed: 69 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -107,40 +107,52 @@ def convert_field_to_list(field, registry=None):
107107
if isinstance(base_type, graphene.Field):
108108
if isinstance(field.field, mongoengine.GenericReferenceField):
109109
def get_reference_objects(*args, **kwargs):
110-
if args[0][1]:
111-
document = get_document(args[0][0])
112-
document_field = mongoengine.ReferenceField(document)
113-
document_field = convert_mongoengine_field(document_field, registry)
114-
document_field_type = document_field.get_type().type._meta.name
115-
required_fields = [to_snake_case(i) for i in
110+
document = get_document(args[0][0])
111+
document_field = mongoengine.ReferenceField(document)
112+
document_field = convert_mongoengine_field(document_field, registry)
113+
document_field_type = document_field.get_type().type._meta.name
114+
required_fields = [to_snake_case(i) for i in
116115
get_query_fields(args[0][3][0])[document_field_type].keys()]
117-
return document.objects().no_dereference().only(*required_fields).filter(pk__in=args[0][1])
118-
else:
119-
return []
116+
return document.objects().no_dereference().only(*required_fields).filter(pk__in=args[0][1])
117+
118+
def get_non_querying_object(*args, **kwargs):
119+
model = get_document(args[0][0])
120+
return [model(pk=each) for each in args[0][1]]
120121

121122
def reference_resolver(root, *args, **kwargs):
122-
choice_to_resolve = dict()
123123
to_resolve = getattr(root, field.name or field.db_name)
124124
if to_resolve:
125+
choice_to_resolve = dict()
126+
querying_union_types = list(get_query_fields(args[0]).keys())
127+
if '__typename' in querying_union_types:
128+
querying_union_types.remove('__typename')
129+
to_resolve_models = list()
130+
for each in querying_union_types:
131+
to_resolve_models.append(registry._registry_string_map[each])
125132
for each in to_resolve:
126133
if each['_cls'] not in choice_to_resolve:
127134
choice_to_resolve[each['_cls']] = list()
128135
choice_to_resolve[each['_cls']].append(each["_ref"].id)
129-
130136
pool = ThreadPoolExecutor(5)
131137
futures = list()
132138
for model, object_id_list in choice_to_resolve.items():
133-
futures.append(pool.submit(get_reference_objects, (model, object_id_list, registry, args)))
139+
if model in to_resolve_models:
140+
futures.append(pool.submit(get_reference_objects, (model, object_id_list, registry, args)))
141+
else:
142+
futures.append(
143+
pool.submit(get_non_querying_object, (model, object_id_list, registry, args)))
134144
result = list()
135145
for x in as_completed(futures):
136146
result += x.result()
137147
to_resolve_object_ids = [each["_ref"].id for each in to_resolve]
138-
result_to_resolve_object_ids = [each.id for each in result]
148+
result_object_ids = list()
149+
for each in result:
150+
result_object_ids.append(each.id)
139151
ordered_result = list()
140152
for each in to_resolve_object_ids:
141-
ordered_result.append(result[result_to_resolve_object_ids.index(each)])
153+
ordered_result.append(result[result_object_ids.index(each)])
142154
return ordered_result
143-
return []
155+
return None
144156

145157
return graphene.List(
146158
base_type._type,
@@ -207,21 +219,29 @@ def convert_field_to_union(field, registry=None):
207219
_union = type(name, (graphene.Union,), {"Meta": Meta})
208220

209221
def reference_resolver(root, *args, **kwargs):
210-
dereferenced = getattr(root, field.name or field.db_name)
211-
if dereferenced:
212-
document = get_document(dereferenced["_cls"])
222+
de_referenced = getattr(root, field.name or field.db_name)
223+
if de_referenced:
224+
document = get_document(de_referenced["_cls"])
213225
document_field = mongoengine.ReferenceField(document)
214226
document_field = convert_mongoengine_field(document_field, registry)
215227
_type = document_field.get_type().type
216-
only_fields = _type._meta.only_fields.split(",") if isinstance(_type._meta.only_fields,
217-
str) else list()
218-
return document.objects().no_dereference().only(*list(
219-
set(only_fields + [to_snake_case(i) for i in get_query_fields(args[0])[_type._meta.name].keys()]))).get(
220-
pk=dereferenced["_ref"].id)
228+
querying_types = list(get_query_fields(args[0]).keys())
229+
_type = document_field.get_type().type
230+
if _type.__name__ in querying_types:
231+
return document.objects().no_dereference().only(*list(
232+
set(list(_type._meta.required_fields) + [to_snake_case(i) for i in
233+
get_query_fields(args[0])[_type._meta.name].keys()]))).get(
234+
pk=de_referenced["_ref"].id)
235+
return document
221236
return None
222237

223238
if isinstance(field, mongoengine.GenericReferenceField):
224-
return graphene.Field(_union, resolver=reference_resolver,
239+
field_resolver = None
240+
if field.db_field is not None:
241+
resolver_function = getattr(_union, "resolve_" + field.db_field, None)
242+
if resolver_function and callable(resolver_function):
243+
field_resolver = resolver_function
244+
return graphene.Field(_union, resolver=field_resolver if field_resolver else reference_resolver,
225245
description=get_field_description(field, registry))
226246

227247
return graphene.Field(_union)
@@ -237,36 +257,40 @@ def reference_resolver(root, *args, **kwargs):
237257
document = getattr(root, field.name or field.db_name)
238258
if document:
239259
_type = registry.get_type_for_model(field.document_type)
240-
required_fields = _type._meta.required_fields.split(",") if isinstance(_type._meta.required_fields,
241-
str) else list()
242260
return field.document_type.objects().no_dereference().only(
243-
*((list(set(required_fields + [to_snake_case(i) for i in get_query_fields(args[0]).keys()]))))).get(
261+
*((list(set(list(_type._meta.required_fields) + [to_snake_case(i) for i in
262+
get_query_fields(args[0]).keys()]))))).get(
244263
pk=document.id)
245264
return None
246265

247266
def cached_reference_resolver(root, *args, **kwargs):
248267
if field:
249268
_type = registry.get_type_for_model(field.document_type)
250-
required_fields = _type._meta.required_fields.split(",") if isinstance(_type._meta.required_fields,
251-
str) else list()
252269
return field.document_type.objects().no_dereference().only(
253-
*(list(set(required_fields + [to_snake_case(i) for i in get_query_fields(args[0]).keys()]))
254-
)).get(
270+
*(list(set(
271+
list(_type._meta.required_fields) + [to_snake_case(i) for i in
272+
get_query_fields(args[0]).keys()])))).get(
255273
pk=getattr(root, field.name or field.db_name))
256274
return None
257275

258276
def dynamic_type():
259277
_type = registry.get_type_for_model(model)
260278
if not _type:
261279
return None
262-
elif isinstance(field, mongoengine.ReferenceField):
263-
return graphene.Field(_type, resolver=reference_resolver,
280+
if isinstance(field, mongoengine.EmbeddedDocumentField):
281+
return graphene.Field(_type,
264282
description=get_field_description(field, registry))
265-
elif isinstance(field, mongoengine.CachedReferenceField):
266-
return graphene.Field(_type, resolver=cached_reference_resolver,
283+
field_resolver = None
284+
if field.db_field is not None:
285+
resolver_function = getattr(_type, "resolve_" + field.db_field, None)
286+
if resolver_function and callable(resolver_function):
287+
field_resolver = resolver_function
288+
if isinstance(field, mongoengine.ReferenceField):
289+
return graphene.Field(_type, resolver=field_resolver if field_resolver else reference_resolver,
290+
description=get_field_description(field, registry))
291+
else:
292+
return graphene.Field(_type, resolver=field_resolver if field_resolver else cached_reference_resolver(),
267293
description=get_field_description(field, registry))
268-
return graphene.Field(_type,
269-
description=get_field_description(field, registry))
270294

271295
return graphene.Dynamic(dynamic_type)
272296

@@ -279,21 +303,24 @@ def lazy_resolver(root, *args, **kwargs):
279303
document = getattr(root, field.name or field.db_name)
280304
if document:
281305
_type = registry.get_type_for_model(document.document_type)
282-
required_fields = _type._meta.required_fields.split(",") if isinstance(_type._meta.required_fields,
283-
str) else list()
284306
return document.document_type.objects().no_dereference().only(
285-
*(list(set((required_fields + [to_snake_case(i) for i in get_query_fields(args[0]).keys()]))))).get(
307+
*(list(set((list(_type._meta.required_fields) + [to_snake_case(i) for i in
308+
get_query_fields(args[0]).keys()]))))).get(
286309
pk=document.pk)
287310
return None
288311

289312
def dynamic_type():
290313
_type = registry.get_type_for_model(model)
291-
292314
if not _type:
293315
return None
316+
field_resolver = None
317+
if field.db_field is not None:
318+
resolver_function = getattr(_type, "resolve_" + field.db_field, None)
319+
if resolver_function and callable(resolver_function):
320+
field_resolver = resolver_function
294321
return graphene.Field(
295322
_type,
296-
resolver=lazy_resolver,
323+
resolver=field_resolver if field_resolver else lazy_resolver,
297324
description=get_field_description(field, registry),
298325
)
299326

graphene_mongo/registry.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
class Registry(object):
22
def __init__(self):
33
self._registry = {}
4+
self._registry_string_map = {}
45

56
def register(self, cls):
67
from .types import MongoengineObjectType
@@ -12,6 +13,7 @@ def register(self, cls):
1213
)
1314
assert cls._meta.registry == self, "Registry for a Model have to match."
1415
self._registry[cls._meta.model] = cls
16+
self._registry_string_map[cls.__name__] = cls._meta.model.__name__
1517

1618
# Rescan all fields
1719
for model, cls in self._registry.items():

0 commit comments

Comments
 (0)