Skip to content

Commit ab2ed59

Browse files
Merge branch 'feat-retrieving-queried-fields-only' into feat-pagination-performance
# Conflicts: # graphene_mongo/converter.py
2 parents cd21053 + 66b2bbc commit ab2ed59

File tree

2 files changed

+48
-34
lines changed

2 files changed

+48
-34
lines changed

graphene_mongo/converter.py

Lines changed: 46 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -107,40 +107,52 @@ def convert_field_to_list(field, registry=None):
107107
if isinstance(base_type, graphene.Field):
108108
if isinstance(field.field, mongoengine.GenericReferenceField):
109109
def get_reference_objects(*args, **kwargs):
110-
if args[0][1]:
111-
document = get_document(args[0][0])
112-
document_field = mongoengine.ReferenceField(document)
113-
document_field = convert_mongoengine_field(document_field, registry)
114-
document_field_type = document_field.get_type().type._meta.name
115-
required_fields = [to_snake_case(i) for i in
116-
get_query_fields(args[0][3][0])[document_field_type].keys()]
117-
return document.objects().no_dereference().only(*required_fields).filter(pk__in=args[0][1])
118-
else:
119-
return []
110+
document = get_document(args[0][0])
111+
document_field = mongoengine.ReferenceField(document)
112+
document_field = convert_mongoengine_field(document_field, registry)
113+
document_field_type = document_field.get_type().type._meta.name
114+
required_fields = [to_snake_case(i) for i in
115+
get_query_fields(args[0][3][0])[document_field_type].keys()]
116+
return document.objects().no_dereference().only(*required_fields).filter(pk__in=args[0][1])
117+
118+
def get_non_querying_object(*args, **kwargs):
119+
model = get_document(args[0][0])
120+
return [model(pk=each) for each in args[0][1]]
120121

121122
def reference_resolver(root, *args, **kwargs):
122-
choice_to_resolve = dict()
123123
to_resolve = getattr(root, field.name or field.db_name)
124124
if to_resolve:
125+
choice_to_resolve = dict()
126+
querying_union_types = list(get_query_fields(args[0]).keys())
127+
if '__typename' in querying_union_types:
128+
querying_union_types.remove('__typename')
129+
to_resolve_models = list()
130+
for each in querying_union_types:
131+
to_resolve_models.append(registry._registry_string_map[each])
125132
for each in to_resolve:
126133
if each['_cls'] not in choice_to_resolve:
127134
choice_to_resolve[each['_cls']] = list()
128135
choice_to_resolve[each['_cls']].append(each["_ref"].id)
129-
130136
pool = ThreadPoolExecutor(5)
131137
futures = list()
132138
for model, object_id_list in choice_to_resolve.items():
133-
futures.append(pool.submit(get_reference_objects, (model, object_id_list, registry, args)))
139+
if model in to_resolve_models:
140+
futures.append(pool.submit(get_reference_objects, (model, object_id_list, registry, args)))
141+
else:
142+
futures.append(
143+
pool.submit(get_non_querying_object, (model, object_id_list, registry, args)))
134144
result = list()
135145
for x in as_completed(futures):
136146
result += x.result()
137147
to_resolve_object_ids = [each["_ref"].id for each in to_resolve]
138-
result_to_resolve_object_ids = [each.id for each in result]
148+
result_object_ids = list()
149+
for each in result:
150+
result_object_ids.append(each.id)
139151
ordered_result = list()
140152
for each in to_resolve_object_ids:
141-
ordered_result.append(result[result_to_resolve_object_ids.index(each)])
153+
ordered_result.append(result[result_object_ids.index(each)])
142154
return ordered_result
143-
return []
155+
return None
144156

145157
return graphene.List(
146158
base_type._type,
@@ -207,17 +219,20 @@ def convert_field_to_union(field, registry=None):
207219
_union = type(name, (graphene.Union,), {"Meta": Meta})
208220

209221
def reference_resolver(root, *args, **kwargs):
210-
dereferenced = getattr(root, field.name or field.db_name)
211-
if dereferenced:
212-
document = get_document(dereferenced["_cls"])
222+
de_referenced = getattr(root, field.name or field.db_name)
223+
if de_referenced:
224+
document = get_document(de_referenced["_cls"])
213225
document_field = mongoengine.ReferenceField(document)
214226
document_field = convert_mongoengine_field(document_field, registry)
215227
_type = document_field.get_type().type
216-
only_fields = _type._meta.only_fields.split(",") if isinstance(_type._meta.only_fields,
217-
str) else list()
218-
return document.objects().no_dereference().only(*list(
219-
set(only_fields + [to_snake_case(i) for i in get_query_fields(args[0])[_type._meta.name].keys()]))).get(
220-
pk=dereferenced["_ref"].id)
228+
querying_types = list(get_query_fields(args[0]).keys())
229+
_type = document_field.get_type().type
230+
if _type.__name__ in querying_types:
231+
return document.objects().no_dereference().only(*list(
232+
set(list(_type._meta.required_fields) + [to_snake_case(i) for i in
233+
get_query_fields(args[0])[_type._meta.name].keys()]))).get(
234+
pk=de_referenced["_ref"].id)
235+
return document
221236
return None
222237

223238
if isinstance(field, mongoengine.GenericReferenceField):
@@ -242,21 +257,19 @@ def reference_resolver(root, *args, **kwargs):
242257
document = getattr(root, field.name or field.db_name)
243258
if document:
244259
_type = registry.get_type_for_model(field.document_type)
245-
required_fields = _type._meta.required_fields.split(",") if isinstance(_type._meta.required_fields,
246-
str) else list()
247260
return field.document_type.objects().no_dereference().only(
248-
*((list(set(required_fields + [to_snake_case(i) for i in get_query_fields(args[0]).keys()]))))).get(
261+
*((list(set(list(_type._meta.required_fields) + [to_snake_case(i) for i in
262+
get_query_fields(args[0]).keys()]))))).get(
249263
pk=document.id)
250264
return None
251265

252266
def cached_reference_resolver(root, *args, **kwargs):
253267
if field:
254268
_type = registry.get_type_for_model(field.document_type)
255-
required_fields = _type._meta.required_fields.split(",") if isinstance(_type._meta.required_fields,
256-
str) else list()
257269
return field.document_type.objects().no_dereference().only(
258-
*(list(set(required_fields + [to_snake_case(i) for i in get_query_fields(args[0]).keys()]))
259-
)).get(
270+
*(list(set(
271+
list(_type._meta.required_fields) + [to_snake_case(i) for i in
272+
get_query_fields(args[0]).keys()])))).get(
260273
pk=getattr(root, field.name or field.db_name))
261274
return None
262275

@@ -290,10 +303,9 @@ def lazy_resolver(root, *args, **kwargs):
290303
document = getattr(root, field.name or field.db_name)
291304
if document:
292305
_type = registry.get_type_for_model(document.document_type)
293-
required_fields = _type._meta.required_fields.split(",") if isinstance(_type._meta.required_fields,
294-
str) else list()
295306
return document.document_type.objects().no_dereference().only(
296-
*(list(set((required_fields + [to_snake_case(i) for i in get_query_fields(args[0]).keys()]))))).get(
307+
*(list(set((list(_type._meta.required_fields) + [to_snake_case(i) for i in
308+
get_query_fields(args[0]).keys()]))))).get(
297309
pk=document.pk)
298310
return None
299311

graphene_mongo/registry.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
class Registry(object):
22
def __init__(self):
33
self._registry = {}
4+
self._registry_string_map = {}
45

56
def register(self, cls):
67
from .types import MongoengineObjectType
@@ -12,6 +13,7 @@ def register(self, cls):
1213
)
1314
assert cls._meta.registry == self, "Registry for a Model have to match."
1415
self._registry[cls._meta.model] = cls
16+
self._registry_string_map[cls.__name__] = cls._meta.model.__name__
1517

1618
# Rescan all fields
1719
for model, cls in self._registry.items():

0 commit comments

Comments
 (0)