Skip to content

Commit c64c759

Browse files
Arun S KumarArun S Kumar
authored andcommitted
Query efficiency and performance - Retrieving only the queried fields from database
1 parent 5ad41a3 commit c64c759

File tree

3 files changed

+97
-26
lines changed

3 files changed

+97
-26
lines changed

graphene_mongo/converter.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from mongoengine.base import get_document
77

88
from . import advanced_types
9-
from .utils import import_single_dispatch, get_field_description
9+
from .utils import import_single_dispatch, get_field_description, get_query_fields
1010

1111
singledispatch = import_single_dispatch()
1212

@@ -186,7 +186,9 @@ def convert_lazy_field_to_dynamic(field, registry=None):
186186

187187
def lazy_resolver(root, *args, **kwargs):
188188
if getattr(root, field.name or field.db_name):
189-
return getattr(root, field.name or field.db_name).fetch()
189+
only_fields = get_query_fields(args[0]).keys()
190+
document = getattr(root, field.name or field.db_name)
191+
return document.document_type.objects().only(*only_fields).get(pk=document.pk)
190192

191193
def dynamic_type():
192194
_type = registry.get_type_for_model(model)

graphene_mongo/fields.py

Lines changed: 30 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import graphene
77
import mongoengine
8+
from graphql.utils.ast_to_dict import ast_to_dict
89
from promise import Promise
910
from graphql_relay import from_global_id
1011
from graphene.relay import ConnectionField
@@ -21,7 +22,7 @@
2122
)
2223
from .converter import convert_mongoengine_field, MongoEngineConversionError
2324
from .registry import get_global_registry
24-
from .utils import get_model_reference_fields, get_node_from_global_id
25+
from .utils import get_model_reference_fields, get_node_from_global_id, get_query_fields, camel_to_snake
2526

2627

2728
class MongoengineConnectionField(ConnectionField):
@@ -98,18 +99,18 @@ def is_filterable(k):
9899
if isinstance(converted, (ConnectionField, Dynamic)):
99100
return False
100101
if callable(getattr(converted, "type", None)) and isinstance(
101-
converted.type(),
102-
(
103-
FileFieldType,
104-
PointFieldType,
105-
MultiPolygonFieldType,
106-
graphene.Union,
107-
PolygonFieldType,
108-
),
102+
converted.type(),
103+
(
104+
FileFieldType,
105+
PointFieldType,
106+
MultiPolygonFieldType,
107+
graphene.Union,
108+
PolygonFieldType,
109+
),
109110
):
110111
return False
111112
if isinstance(converted, (graphene.List)) and issubclass(
112-
getattr(converted, "_of_type", None), graphene.Union
113+
getattr(converted, "_of_type", None), graphene.Union
113114
):
114115
return False
115116

@@ -160,16 +161,16 @@ def get_reference_field(r, kv):
160161
field = kv[1]
161162
mongo_field = getattr(self.model, kv[0], None)
162163
if isinstance(
163-
mongo_field,
164-
(mongoengine.LazyReferenceField, mongoengine.ReferenceField),
164+
mongo_field,
165+
(mongoengine.LazyReferenceField, mongoengine.ReferenceField),
165166
):
166167
field = convert_mongoengine_field(mongo_field, self.registry)
167168
if callable(getattr(field, "get_type", None)):
168169
_type = field.get_type()
169170
if _type:
170171
node = _type._type._meta
171172
if "id" in node.fields and not issubclass(
172-
node.model, (mongoengine.EmbeddedDocument,)
173+
node.model, (mongoengine.EmbeddedDocument,)
173174
):
174175
r.update({kv[0]: node.fields["id"]._type.of_type()})
175176
return r
@@ -180,7 +181,7 @@ def get_reference_field(r, kv):
180181
def fields(self):
181182
return self._type._meta.fields
182183

183-
def get_queryset(self, model, info, **args):
184+
def get_queryset(self, model, info, only_fields=list(), **args):
184185
if args:
185186
reference_fields = get_model_reference_fields(self.model)
186187
hydrated_references = {}
@@ -198,12 +199,13 @@ def get_queryset(self, model, info, **args):
198199
return queryset_or_filters
199200
else:
200201
args.update(queryset_or_filters)
202+
201203
return model.objects(**args).order_by(self.order_by)
202204

203-
def default_resolver(self, _root, info, **args):
205+
def default_resolver(self, _root, info, only_fields=list(), **args):
204206
args = args or {}
205207

206-
if _root is not None:
208+
if _root is not None and getattr(_root, info.field_name, []) is not None:
207209
args["pk__in"] = [r.pk for r in getattr(_root, info.field_name, [])]
208210

209211
connection_args = {
@@ -219,7 +221,7 @@ def default_resolver(self, _root, info, **args):
219221
args['pk'] = from_global_id(_id)[-1]
220222

221223
if callable(getattr(self.model, "objects", None)):
222-
iterables = self.get_queryset(self.model, info, **args)
224+
iterables = self.get_queryset(self.model, info, only_fields, **args)
223225
list_length = iterables.count()
224226
else:
225227
iterables = []
@@ -239,23 +241,31 @@ def default_resolver(self, _root, info, **args):
239241
return connection
240242

241243
def chained_resolver(self, resolver, is_partial, root, info, **args):
244+
only_fields = list()
245+
for field in get_query_fields(info):
246+
if camel_to_snake(field) in self.model._fields_ordered:
247+
only_fields.append(camel_to_snake(field))
242248
if not bool(args) or not is_partial:
249+
if isinstance(self.model, mongoengine.Document) or isinstance(self.model,
250+
mongoengine.base.metaclasses.TopLevelDocumentMetaclass):
251+
args_copy = args.copy()
252+
for arg_name, arg in args.copy().items():
253+
if arg_name not in self.model._fields_ordered:
254+
args_copy.pop(arg_name)
243255
# XXX: Filter nested args
244256
resolved = resolver(root, info, **args)
245257
if resolved is not None:
246258
return resolved
247-
return self.default_resolver(root, info, **args)
259+
return self.default_resolver(root, info, only_fields, **args)
248260

249261
@classmethod
250262
def connection_resolver(cls, resolver, connection_type, root, info, **args):
251263
iterable = resolver(root, info, **args)
252264
if isinstance(connection_type, graphene.NonNull):
253265
connection_type = connection_type.of_type
254-
255266
on_resolve = partial(cls.resolve_connection, connection_type, args)
256267
if Promise.is_thenable(iterable):
257268
return Promise.resolve(iterable).then(on_resolve)
258-
259269
return on_resolve(iterable)
260270

261271
def get_resolver(self, parent_resolver):

graphene_mongo/utils.py

Lines changed: 63 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import mongoengine
77
from graphene import Node
88
from graphene.utils.trim_docstring import trim_docstring
9+
from graphql.utils.ast_to_dict import ast_to_dict
910

1011

1112
def get_model_fields(model, excluding=None):
@@ -23,8 +24,8 @@ def get_model_reference_fields(model, excluding=None):
2324
attributes = dict()
2425
for attr_name, attr in model._fields.items():
2526
if attr_name in excluding or not isinstance(
26-
attr,
27-
(mongoengine.fields.ReferenceField, mongoengine.fields.LazyReferenceField),
27+
attr,
28+
(mongoengine.fields.ReferenceField, mongoengine.fields.LazyReferenceField),
2829
):
2930
continue
3031
attributes[attr_name] = attr
@@ -33,8 +34,8 @@ def get_model_reference_fields(model, excluding=None):
3334

3435
def is_valid_mongoengine_model(model):
3536
return inspect.isclass(model) and (
36-
issubclass(model, mongoengine.Document)
37-
or issubclass(model, mongoengine.EmbeddedDocument)
37+
issubclass(model, mongoengine.Document)
38+
or issubclass(model, mongoengine.EmbeddedDocument)
3839
)
3940

4041

@@ -101,3 +102,61 @@ def get_node_from_global_id(node, info, global_id):
101102
return interface.get_node_from_global_id(info, global_id)
102103
except AttributeError:
103104
return Node.get_node_from_global_id(info, global_id)
105+
106+
107+
def collect_query_fields(node, fragments):
108+
"""Recursively collects fields from the AST
109+
110+
Args:
111+
node (dict): A node in the AST
112+
fragments (dict): Fragment definitions
113+
114+
Returns:
115+
A dict mapping each field found, along with their sub fields.
116+
117+
{'name': {},
118+
'image': {'id': {},
119+
'name': {},
120+
'description': {}},
121+
'slug': {}}
122+
"""
123+
124+
field = {}
125+
126+
if node.get('selection_set'):
127+
for leaf in node['selection_set']['selections']:
128+
if leaf['kind'] == 'Field':
129+
field.update({
130+
leaf['name']['value']: collect_query_fields(leaf, fragments)
131+
})
132+
elif leaf['kind'] == 'FragmentSpread':
133+
field.update(collect_query_fields(fragments[leaf['name']['value']],
134+
fragments))
135+
136+
return field
137+
138+
139+
def get_query_fields(info):
140+
"""A convenience function to call collect_query_fields with info
141+
142+
Args:
143+
info (ResolveInfo)
144+
145+
Returns:
146+
dict: Returned from collect_query_fields
147+
"""
148+
149+
fragments = {}
150+
node = ast_to_dict(info.field_asts[0])
151+
152+
for name, value in info.fragments.items():
153+
fragments[name] = ast_to_dict(value)
154+
155+
query = collect_query_fields(node, fragments)
156+
if "edges" in query:
157+
return query["edges"]["node"].keys()
158+
return query
159+
160+
161+
def camel_to_snake(field):
162+
return ''.join(['_' + c.lower() if c.isupper() else c for c in field]).lstrip('_')

0 commit comments

Comments
 (0)