Skip to content

Commit 5d08133

Browse files
authored
Merge pull request #156 from arunsureshkumar/feat-retrieving-queried-fields-only
Feat retrieving queried fields only
2 parents 7ca0925 + 74e6094 commit 5d08133

File tree

9 files changed

+274
-51
lines changed

9 files changed

+274
-51
lines changed

examples/flask_mongoengine/database.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from mongoengine import connect
22

3-
from models import Department, Employee, Role, Task
3+
from .models import Department, Employee, Role, Task
44

55
connect("graphene-mongo-example", host="mongomock://localhost", alias="default")
66

examples/flask_mongoengine/schema.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
import graphene
22
from graphene.relay import Node
3+
from graphene_mongo.tests.nodes import PlayerNode, ReporterNode
4+
35
from graphene_mongo import MongoengineConnectionField, MongoengineObjectType
4-
from models import Department as DepartmentModel
5-
from models import Employee as EmployeeModel
6-
from models import Role as RoleModel
7-
from models import Task as TaskModel
6+
from .models import Department as DepartmentModel
7+
from .models import Employee as EmployeeModel
8+
from .models import Role as RoleModel
9+
from .models import Task as TaskModel
810

911

1012
class Department(MongoengineObjectType):
@@ -17,6 +19,9 @@ class Role(MongoengineObjectType):
1719
class Meta:
1820
model = RoleModel
1921
interfaces = (Node,)
22+
filter_fields = {
23+
'name': ['exact', 'icontains', 'istartswith']
24+
}
2025

2126

2227
class Task(MongoengineObjectType):
@@ -29,6 +34,9 @@ class Employee(MongoengineObjectType):
2934
class Meta:
3035
model = EmployeeModel
3136
interfaces = (Node,)
37+
filter_fields = {
38+
'name': ['exact', 'icontains', 'istartswith']
39+
}
3240

3341

3442
class Query(graphene.ObjectType):

graphene_mongo/converter.py

Lines changed: 106 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,11 @@
33
import uuid
44

55
from graphene.types.json import JSONString
6+
from graphene.utils.str_converters import to_snake_case
67
from mongoengine.base import get_document
7-
88
from . import advanced_types
9-
from .utils import import_single_dispatch, get_field_description
9+
from .utils import import_single_dispatch, get_field_description, get_query_fields
10+
from concurrent.futures import ThreadPoolExecutor, as_completed
1011

1112
singledispatch = import_single_dispatch()
1213

@@ -104,6 +105,49 @@ def convert_file_to_field(field, registry=None):
104105
def convert_field_to_list(field, registry=None):
105106
base_type = convert_mongoengine_field(field.field, registry=registry)
106107
if isinstance(base_type, graphene.Field):
108+
if isinstance(field.field, mongoengine.GenericReferenceField):
109+
def get_reference_objects(*args, **kwargs):
110+
if args[0][1]:
111+
document = get_document(args[0][0])
112+
document_field = mongoengine.ReferenceField(document)
113+
document_field = convert_mongoengine_field(document_field, registry)
114+
document_field_type = document_field.get_type().type._meta.name
115+
only_fields = [to_snake_case(i) for i in
116+
get_query_fields(args[0][3][0])[document_field_type].keys()]
117+
return document.objects().no_dereference().only(*only_fields).filter(pk__in=args[0][1])
118+
else:
119+
return []
120+
121+
def reference_resolver(root, *args, **kwargs):
122+
choice_to_resolve = dict()
123+
to_resolve = getattr(root, field.name or field.db_name)
124+
if to_resolve:
125+
for each in to_resolve:
126+
if each['_cls'] not in choice_to_resolve:
127+
choice_to_resolve[each['_cls']] = list()
128+
choice_to_resolve[each['_cls']].append(each["_ref"].id)
129+
130+
pool = ThreadPoolExecutor(5)
131+
futures = list()
132+
for model, object_id_list in choice_to_resolve.items():
133+
futures.append(pool.submit(get_reference_objects, (model, object_id_list, registry, args)))
134+
result = list()
135+
for x in as_completed(futures):
136+
result += x.result()
137+
to_resolve_object_ids = [each["_ref"].id for each in to_resolve]
138+
result_to_resolve_object_ids = [each.id for each in result]
139+
ordered_result = list()
140+
for each in to_resolve_object_ids:
141+
ordered_result.append(result[result_to_resolve_object_ids.index(each)])
142+
return ordered_result
143+
return []
144+
145+
return graphene.List(
146+
base_type._type,
147+
description=get_field_description(field, registry),
148+
required=field.required,
149+
resolver=reference_resolver
150+
)
107151
return graphene.List(
108152
base_type._type,
109153
description=get_field_description(field, registry),
@@ -121,7 +165,7 @@ def convert_field_to_list(field, registry=None):
121165
# Non-relationship field
122166
relations = (mongoengine.ReferenceField, mongoengine.EmbeddedDocumentField)
123167
if not isinstance(base_type, (graphene.List, graphene.NonNull)) and not isinstance(
124-
field.field, relations
168+
field.field, relations
125169
):
126170
base_type = type(base_type)
127171

@@ -135,7 +179,6 @@ def convert_field_to_list(field, registry=None):
135179
@convert_mongoengine_field.register(mongoengine.GenericEmbeddedDocumentField)
136180
@convert_mongoengine_field.register(mongoengine.GenericReferenceField)
137181
def convert_field_to_union(field, registry=None):
138-
139182
_types = []
140183
for choice in field.choices:
141184
if isinstance(field, mongoengine.GenericReferenceField):
@@ -162,6 +205,25 @@ def convert_field_to_union(field, registry=None):
162205
)
163206
Meta = type("Meta", (object,), {"types": tuple(_types)})
164207
_union = type(name, (graphene.Union,), {"Meta": Meta})
208+
209+
def reference_resolver(root, *args, **kwargs):
210+
dereferenced = getattr(root, field.name or field.db_name)
211+
if dereferenced:
212+
document = get_document(dereferenced["_cls"])
213+
document_field = mongoengine.ReferenceField(document)
214+
document_field = convert_mongoengine_field(document_field, registry)
215+
_type = document_field.get_type().type
216+
only_fields = _type._meta.only_fields.split(",") if isinstance(_type._meta.only_fields,
217+
str) else list()
218+
return document.objects().no_dereference().only(*list(
219+
set(only_fields + [to_snake_case(i) for i in get_query_fields(args[0])[_type._meta.name].keys()]))).get(
220+
pk=dereferenced["_ref"].id)
221+
return None
222+
223+
if isinstance(field, mongoengine.GenericReferenceField):
224+
return graphene.Field(_union, resolver=reference_resolver,
225+
description=get_field_description(field, registry))
226+
165227
return graphene.Field(_union)
166228

167229

@@ -171,11 +233,40 @@ def convert_field_to_union(field, registry=None):
171233
def convert_field_to_dynamic(field, registry=None):
172234
model = field.document_type
173235

236+
def reference_resolver(root, *args, **kwargs):
237+
document = getattr(root, field.name or field.db_name)
238+
if document:
239+
_type = registry.get_type_for_model(field.document_type)
240+
only_fields = _type._meta.only_fields.split(",") if isinstance(_type._meta.only_fields,
241+
str) else list()
242+
return field.document_type.objects().no_dereference().only(
243+
*((list(set(only_fields + [to_snake_case(i) for i in get_query_fields(args[0]).keys()]))))).get(
244+
pk=document.id)
245+
return None
246+
247+
def cached_reference_resolver(root, *args, **kwargs):
248+
if field:
249+
_type = registry.get_type_for_model(field.document_type)
250+
only_fields = _type._meta.only_fields.split(",") if isinstance(_type._meta.only_fields,
251+
str) else list()
252+
return field.document_type.objects().no_dereference().only(
253+
*(list(set(only_fields + [to_snake_case(i) for i in get_query_fields(args[0]).keys()]))
254+
)).get(
255+
pk=getattr(root, field.name or field.db_name))
256+
return None
257+
174258
def dynamic_type():
175259
_type = registry.get_type_for_model(model)
176260
if not _type:
177261
return None
178-
return graphene.Field(_type, description=get_field_description(field, registry))
262+
elif isinstance(field, mongoengine.ReferenceField):
263+
return graphene.Field(_type, resolver=reference_resolver,
264+
description=get_field_description(field, registry))
265+
elif isinstance(field, mongoengine.CachedReferenceField):
266+
return graphene.Field(_type, resolver=cached_reference_resolver,
267+
description=get_field_description(field, registry))
268+
return graphene.Field(_type,
269+
description=get_field_description(field, registry))
179270

180271
return graphene.Dynamic(dynamic_type)
181272

@@ -185,11 +276,19 @@ def convert_lazy_field_to_dynamic(field, registry=None):
185276
model = field.document_type
186277

187278
def lazy_resolver(root, *args, **kwargs):
188-
if getattr(root, field.name or field.db_name):
189-
return getattr(root, field.name or field.db_name).fetch()
279+
document = getattr(root, field.name or field.db_name)
280+
if document:
281+
_type = registry.get_type_for_model(document.document_type)
282+
only_fields = _type._meta.only_fields.split(",") if isinstance(_type._meta.only_fields,
283+
str) else list()
284+
return document.document_type.objects().no_dereference().only(
285+
*(list(set((only_fields + [to_snake_case(i) for i in get_query_fields(args[0]).keys()]))))).get(
286+
pk=document.pk)
287+
return None
190288

191289
def dynamic_type():
192290
_type = registry.get_type_for_model(model)
291+
193292
if not _type:
194293
return None
195294
return graphene.Field(

graphene_mongo/fields.py

Lines changed: 60 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,10 @@
55

66
import graphene
77
import mongoengine
8+
from bson import DBRef
9+
from graphene import Context
10+
from graphene.utils.str_converters import to_snake_case
11+
from graphql import ResolveInfo
812
from promise import Promise
913
from graphql_relay import from_global_id
1014
from graphene.relay import ConnectionField
@@ -21,7 +25,7 @@
2125
)
2226
from .converter import convert_mongoengine_field, MongoEngineConversionError
2327
from .registry import get_global_registry
24-
from .utils import get_model_reference_fields, get_node_from_global_id
28+
from .utils import get_model_reference_fields, get_node_from_global_id, get_query_fields
2529

2630

2731
class MongoengineConnectionField(ConnectionField):
@@ -59,6 +63,12 @@ def model(self):
5963
def order_by(self):
6064
return self.node_type._meta.order_by
6165

66+
@property
67+
def only_fields(self):
68+
if isinstance(self.node_type._meta.only_fields, str):
69+
return self.node_type._meta.only_fields.split(",")
70+
return list()
71+
6272
@property
6373
def registry(self):
6474
return getattr(self.node_type._meta, "registry", get_global_registry())
@@ -98,18 +108,18 @@ def is_filterable(k):
98108
if isinstance(converted, (ConnectionField, Dynamic)):
99109
return False
100110
if callable(getattr(converted, "type", None)) and isinstance(
101-
converted.type(),
102-
(
103-
FileFieldType,
104-
PointFieldType,
105-
MultiPolygonFieldType,
106-
graphene.Union,
107-
PolygonFieldType,
108-
),
111+
converted.type(),
112+
(
113+
FileFieldType,
114+
PointFieldType,
115+
MultiPolygonFieldType,
116+
graphene.Union,
117+
PolygonFieldType,
118+
),
109119
):
110120
return False
111121
if isinstance(converted, (graphene.List)) and issubclass(
112-
getattr(converted, "_of_type", None), graphene.Union
122+
getattr(converted, "_of_type", None), graphene.Union
113123
):
114124
return False
115125

@@ -160,16 +170,16 @@ def get_reference_field(r, kv):
160170
field = kv[1]
161171
mongo_field = getattr(self.model, kv[0], None)
162172
if isinstance(
163-
mongo_field,
164-
(mongoengine.LazyReferenceField, mongoengine.ReferenceField),
173+
mongo_field,
174+
(mongoengine.LazyReferenceField, mongoengine.ReferenceField),
165175
):
166176
field = convert_mongoengine_field(mongo_field, self.registry)
167177
if callable(getattr(field, "get_type", None)):
168178
_type = field.get_type()
169179
if _type:
170180
node = _type._type._meta
171181
if "id" in node.fields and not issubclass(
172-
node.model, (mongoengine.EmbeddedDocument,)
182+
node.model, (mongoengine.EmbeddedDocument,)
173183
):
174184
r.update({kv[0]: node.fields["id"]._type.of_type()})
175185
return r
@@ -180,7 +190,7 @@ def get_reference_field(r, kv):
180190
def fields(self):
181191
return self._type._meta.fields
182192

183-
def get_queryset(self, model, info, **args):
193+
def get_queryset(self, model, info, only_fields=list(), **args):
184194
if args:
185195
reference_fields = get_model_reference_fields(self.model)
186196
hydrated_references = {}
@@ -198,13 +208,16 @@ def get_queryset(self, model, info, **args):
198208
return queryset_or_filters
199209
else:
200210
args.update(queryset_or_filters)
201-
return model.objects(**args).order_by(self.order_by)
202211

203-
def default_resolver(self, _root, info, **args):
212+
return model.objects(**args).no_dereference().only(*only_fields).order_by(self.order_by)
213+
214+
def default_resolver(self, _root, info, only_fields=list(), **args):
204215
args = args or {}
205216

206217
if _root is not None:
207-
args["pk__in"] = [r.pk for r in getattr(_root, info.field_name, [])]
218+
field_name = to_snake_case(info.field_name)
219+
if getattr(_root, field_name, []) is not None:
220+
args["pk__in"] = [r.id for r in getattr(_root, field_name, [])]
208221

209222
connection_args = {
210223
"first": args.pop("first", None),
@@ -219,7 +232,11 @@ def default_resolver(self, _root, info, **args):
219232
args['pk'] = from_global_id(_id)[-1]
220233

221234
if callable(getattr(self.model, "objects", None)):
222-
iterables = self.get_queryset(self.model, info, **args)
235+
iterables = self.get_queryset(self.model, info, only_fields, **args)
236+
if isinstance(info, ResolveInfo):
237+
if not info.context:
238+
info.context = Context()
239+
info.context.queryset = iterables
223240
list_length = iterables.count()
224241
else:
225242
iterables = []
@@ -239,23 +256,44 @@ def default_resolver(self, _root, info, **args):
239256
return connection
240257

241258
def chained_resolver(self, resolver, is_partial, root, info, **args):
259+
only_fields = list()
260+
for field in self.only_fields:
261+
if field in self.model._fields_ordered:
262+
only_fields.append(field)
263+
for field in get_query_fields(info):
264+
if to_snake_case(field) in self.model._fields_ordered:
265+
only_fields.append(to_snake_case(field))
242266
if not bool(args) or not is_partial:
267+
if isinstance(self.model, mongoengine.Document) or isinstance(self.model,
268+
mongoengine.base.metaclasses.TopLevelDocumentMetaclass):
269+
args_copy = args.copy()
270+
for arg_name, arg in args.copy().items():
271+
if arg_name not in self.model._fields_ordered:
272+
args_copy.pop(arg_name)
273+
if isinstance(info, ResolveInfo):
274+
if not info.context:
275+
info.context = Context()
276+
info.context.queryset = self.get_queryset(self.model, info, only_fields, **args_copy)
243277
# XXX: Filter nested args
244278
resolved = resolver(root, info, **args)
245279
if resolved is not None:
246-
return resolved
247-
return self.default_resolver(root, info, **args)
280+
if isinstance(resolved, list):
281+
if resolved == list():
282+
return resolved
283+
elif not isinstance(resolved[0], DBRef):
284+
return resolved
285+
else:
286+
return resolved
287+
return self.default_resolver(root, info, only_fields, **args)
248288

249289
@classmethod
250290
def connection_resolver(cls, resolver, connection_type, root, info, **args):
251291
iterable = resolver(root, info, **args)
252292
if isinstance(connection_type, graphene.NonNull):
253293
connection_type = connection_type.of_type
254-
255294
on_resolve = partial(cls.resolve_connection, connection_type, args)
256295
if Promise.is_thenable(iterable):
257296
return Promise.resolve(iterable).then(on_resolve)
258-
259297
return on_resolve(iterable)
260298

261299
def get_resolver(self, parent_resolver):

graphene_mongo/tests/test_relay_query.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ class Query(graphene.ObjectType):
1717
reporter = graphene.Field(nodes.ReporterNode)
1818

1919
def resolve_reporter(self, *args, **kwargs):
20-
return models.Reporter.objects.first()
20+
return models.Reporter.objects.no_dereference().first()
2121

2222
query = """
2323
query ReporterQuery {

0 commit comments

Comments
 (0)