Skip to content

Commit 726389e

Browse files
Implemented NonNull if fields are marked as required in db for all type Reference fields including [Union Type]
Bug fixed in queryset reference args
1 parent 89bc945 commit 726389e

File tree

5 files changed

+25
-12
lines changed

5 files changed

+25
-12
lines changed

graphene_mongo/converter.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -237,12 +237,14 @@ def reference_resolver(root, *args, **kwargs):
237237

238238
if isinstance(field, mongoengine.GenericReferenceField):
239239
field_resolver = None
240+
required = False
240241
if field.db_field is not None:
242+
required = field.required
241243
resolver_function = getattr(_union, "resolve_" + field.db_field, None)
242244
if resolver_function and callable(resolver_function):
243245
field_resolver = resolver_function
244246
return graphene.Field(_union, resolver=field_resolver if field_resolver else reference_resolver,
245-
description=get_field_description(field, registry))
247+
description=get_field_description(field, registry), required=required)
246248

247249
return graphene.Field(_union)
248250

@@ -281,16 +283,18 @@ def dynamic_type():
281283
return graphene.Field(_type,
282284
description=get_field_description(field, registry))
283285
field_resolver = None
286+
required = False
284287
if field.db_field is not None:
288+
required = field.required
285289
resolver_function = getattr(_type, "resolve_" + field.db_field, None)
286290
if resolver_function and callable(resolver_function):
287291
field_resolver = resolver_function
288292
if isinstance(field, mongoengine.ReferenceField):
289293
return graphene.Field(_type, resolver=field_resolver if field_resolver else reference_resolver,
290-
description=get_field_description(field, registry))
294+
description=get_field_description(field, registry), required=required)
291295
else:
292-
return graphene.Field(_type, resolver=field_resolver if field_resolver else cached_reference_resolver(),
293-
description=get_field_description(field, registry))
296+
return graphene.Field(_type, resolver=field_resolver if field_resolver else cached_reference_resolver,
297+
description=get_field_description(field, registry), required=required)
294298

295299
return graphene.Dynamic(dynamic_type)
296300

@@ -314,14 +318,16 @@ def dynamic_type():
314318
if not _type:
315319
return None
316320
field_resolver = None
321+
required = False
317322
if field.db_field is not None:
323+
required = field.required
318324
resolver_function = getattr(_type, "resolve_" + field.db_field, None)
319325
if resolver_function and callable(resolver_function):
320326
field_resolver = resolver_function
321327
return graphene.Field(
322328
_type,
323329
resolver=field_resolver if field_resolver else lazy_resolver,
324-
description=get_field_description(field, registry),
330+
description=get_field_description(field, registry), required=required,
325331
)
326332

327333
return graphene.Dynamic(dynamic_type)

graphene_mongo/fields.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,9 @@ def is_filterable(k):
117117
),
118118
):
119119
return False
120+
if getattr(converted, "type", None) and getattr(converted.type, "_of_type", None) and issubclass(
121+
(get_type(converted.type.of_type)), graphene.Union):
122+
return False
120123
if isinstance(converted, (graphene.List)) and issubclass(
121124
getattr(converted, "_of_type", None), graphene.Union
122125
):
@@ -176,7 +179,7 @@ def get_reference_field(r, kv):
176179
if callable(getattr(field, "get_type", None)):
177180
_type = field.get_type()
178181
if _type:
179-
node = _type._type._meta
182+
node = _type.type._meta if hasattr(_type.type, "_meta") else _type.type._of_type._meta
180183
if "id" in node.fields and not issubclass(
181184
node.model, (mongoengine.EmbeddedDocument,)
182185
):
@@ -195,7 +198,7 @@ def get_queryset(self, model, info, required_fields=list(), **args):
195198
reference_fields = get_model_reference_fields(self.model)
196199
hydrated_references = {}
197200
for arg_name, arg in args.copy().items():
198-
if arg_name in reference_fields:
201+
if arg_name in reference_fields and isinstance(arg, str):
199202
reference_obj = get_node_from_global_id(
200203
reference_fields[arg_name], info, args.pop(arg_name)
201204
)

graphene_mongo/tests/models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ class Reporter(mongoengine.Document):
8484
mongoengine.EmbeddedDocumentField(EmbeddedArticle)
8585
)
8686
embedded_list_articles = mongoengine.EmbeddedDocumentListField(EmbeddedArticle)
87-
generic_reference = mongoengine.GenericReferenceField(choices=[Article, Editor])
87+
generic_reference = mongoengine.GenericReferenceField(choices=[Article, Editor],required=True)
8888
generic_embedded_document = mongoengine.GenericEmbeddedDocumentField(
8989
choices=[EmbeddedArticle, EmbeddedFoo]
9090
)

graphene_mongo/tests/test_converter.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -352,8 +352,12 @@ class Meta:
352352
Reporter._fields["generic_reference"], registry.get_global_registry()
353353
)
354354
assert isinstance(generic_reference_field, graphene.Field)
355-
assert isinstance(generic_reference_field.type(), graphene.Union)
356-
assert generic_reference_field.type()._meta.types == (A, E)
355+
if not Reporter._fields["generic_reference"].required:
356+
assert isinstance(generic_reference_field.type(), graphene.Union)
357+
assert generic_reference_field.type()._meta.types == (A, E)
358+
else:
359+
assert issubclass(generic_reference_field.type.of_type, graphene.Union)
360+
assert generic_reference_field.type.of_type._meta.types == (A, E)
357361

358362

359363
def test_should_generic_embedded_document_convert_union():

graphene_mongo/types.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from .utils import get_model_fields, is_valid_mongoengine_model, get_query_fields
1414

1515

16-
def construct_fields(model, registry, required_fields, exclude_fields):
16+
def construct_fields(model, registry, only_fields, exclude_fields):
1717
"""
1818
Args:
1919
model (mongoengine.Document):
@@ -29,7 +29,7 @@ def construct_fields(model, registry, required_fields, exclude_fields):
2929
fields = OrderedDict()
3030
self_referenced = OrderedDict()
3131
for name, field in _model_fields.items():
32-
is_not_in_only = required_fields and name not in required_fields
32+
is_not_in_only = only_fields and name not in only_fields
3333
is_excluded = name in exclude_fields
3434
if is_not_in_only or is_excluded:
3535
# We skip this field if we specify required_fields and is not

0 commit comments

Comments
 (0)