Skip to content

Commit 1ff79aa

Browse files
Merge branch 'feat-retrieving-queried-fields-only' into feat-pagination-performance
2 parents faf834d + ed4b1a1 commit 1ff79aa

File tree

3 files changed

+15
-9
lines changed

3 files changed

+15
-9
lines changed

graphene_mongo/converter.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -237,12 +237,14 @@ def reference_resolver(root, *args, **kwargs):
237237

238238
if isinstance(field, mongoengine.GenericReferenceField):
239239
field_resolver = None
240+
required = False
240241
if field.db_field is not None:
242+
required = field.required
241243
resolver_function = getattr(_union, "resolve_" + field.db_field, None)
242244
if resolver_function and callable(resolver_function):
243245
field_resolver = resolver_function
244246
return graphene.Field(_union, resolver=field_resolver if field_resolver else reference_resolver,
245-
description=get_field_description(field, registry))
247+
description=get_field_description(field, registry), required=required)
246248

247249
return graphene.Field(_union)
248250

@@ -281,16 +283,18 @@ def dynamic_type():
281283
return graphene.Field(_type,
282284
description=get_field_description(field, registry))
283285
field_resolver = None
286+
required = False
284287
if field.db_field is not None:
288+
required = field.required
285289
resolver_function = getattr(_type, "resolve_" + field.db_field, None)
286290
if resolver_function and callable(resolver_function):
287291
field_resolver = resolver_function
288292
if isinstance(field, mongoengine.ReferenceField):
289293
return graphene.Field(_type, resolver=field_resolver if field_resolver else reference_resolver,
290-
description=get_field_description(field, registry))
294+
description=get_field_description(field, registry), required=required)
291295
else:
292-
return graphene.Field(_type, resolver=field_resolver if field_resolver else cached_reference_resolver(),
293-
description=get_field_description(field, registry))
296+
return graphene.Field(_type, resolver=field_resolver if field_resolver else cached_reference_resolver,
297+
description=get_field_description(field, registry), required=required)
294298

295299
return graphene.Dynamic(dynamic_type)
296300

@@ -314,14 +318,16 @@ def dynamic_type():
314318
if not _type:
315319
return None
316320
field_resolver = None
321+
required = False
317322
if field.db_field is not None:
323+
required = field.required
318324
resolver_function = getattr(_type, "resolve_" + field.db_field, None)
319325
if resolver_function and callable(resolver_function):
320326
field_resolver = resolver_function
321327
return graphene.Field(
322328
_type,
323329
resolver=field_resolver if field_resolver else lazy_resolver,
324-
description=get_field_description(field, registry),
330+
description=get_field_description(field, registry), required=required,
325331
)
326332

327333
return graphene.Dynamic(dynamic_type)

graphene_mongo/fields.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ def get_reference_field(r, kv):
178178
if callable(getattr(field, "get_type", None)):
179179
_type = field.get_type()
180180
if _type:
181-
node = _type._type._meta
181+
node = _type.type._meta if hasattr(_type.type, "_meta") else _type.type._of_type._meta
182182
if "id" in node.fields and not issubclass(
183183
node.model, (mongoengine.EmbeddedDocument,)
184184
):
@@ -197,7 +197,7 @@ def get_queryset(self, model, info, required_fields=list(), skip=None, limit=Non
197197
reference_fields = get_model_reference_fields(self.model)
198198
hydrated_references = {}
199199
for arg_name, arg in args.copy().items():
200-
if arg_name in reference_fields:
200+
if arg_name in reference_fields and isinstance(arg, str):
201201
reference_obj = get_node_from_global_id(
202202
reference_fields[arg_name], info, args.pop(arg_name)
203203
)

graphene_mongo/types.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from .utils import get_model_fields, is_valid_mongoengine_model, get_query_fields
1414

1515

16-
def construct_fields(model, registry, required_fields, exclude_fields):
16+
def construct_fields(model, registry, only_fields, exclude_fields):
1717
"""
1818
Args:
1919
model (mongoengine.Document):
@@ -29,7 +29,7 @@ def construct_fields(model, registry, required_fields, exclude_fields):
2929
fields = OrderedDict()
3030
self_referenced = OrderedDict()
3131
for name, field in _model_fields.items():
32-
is_not_in_only = required_fields and name not in required_fields
32+
is_not_in_only = only_fields and name not in only_fields
3333
is_excluded = name in exclude_fields
3434
if is_not_in_only or is_excluded:
3535
# We skip this field if we specify required_fields and is not

0 commit comments

Comments
 (0)