Skip to content

Commit f93e5c7

Browse files
Implemented NonNull if fields are marked as required in db for Generic Reference fields [Union Type]
1 parent 7fe84bd commit f93e5c7

File tree

5 files changed

+17
-20
lines changed

5 files changed

+17
-20
lines changed

graphene_mongo/converter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,7 @@ def reference_resolver(root, *args, **kwargs):
244244
if resolver_function and callable(resolver_function):
245245
field_resolver = resolver_function
246246
return graphene.Field(_union, resolver=field_resolver if field_resolver else reference_resolver,
247-
description=get_field_description(field, registry))
247+
description=get_field_description(field, registry), required=required)
248248

249249
return graphene.Field(_union)
250250

graphene_mongo/fields.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,11 +119,13 @@ def is_filterable(k):
119119
),
120120
):
121121
return False
122+
if getattr(converted, "type", None) and getattr(converted.type, "_of_type", None) and issubclass(
123+
(get_type(converted.type.of_type)), graphene.Union):
124+
return False
122125
if isinstance(converted, (graphene.List)) and issubclass(
123126
getattr(converted, "_of_type", None), graphene.Union
124127
):
125128
return False
126-
127129
return True
128130

129131
def get_filter_type(_type):
@@ -183,6 +185,7 @@ def get_reference_field(r, kv):
183185
node.model, (mongoengine.EmbeddedDocument,)
184186
):
185187
r.update({kv[0]: node.fields["id"]._type.of_type()})
188+
186189
return r
187190

188191
return reduce(get_reference_field, self.fields.items(), {})

graphene_mongo/tests/models.py

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,12 @@
66
mongoengine.connect(
77
"graphene-mongo-test", host="mongomock://localhost", alias="default"
88
)
9+
10+
911
# mongoengine.connect('graphene-mongo-test', host='mongodb://localhost/graphene-mongo-dev')
1012

1113

1214
class Publisher(mongoengine.Document):
13-
1415
meta = {"collection": "test_publisher"}
1516
name = mongoengine.StringField()
1617

@@ -42,7 +43,6 @@ class Editor(mongoengine.Document):
4243

4344

4445
class Article(mongoengine.Document):
45-
4646
meta = {"collection": "test_article"}
4747
headline = mongoengine.StringField(required=True, help_text="The article headline.")
4848
pub_date = mongoengine.DateTimeField(
@@ -58,7 +58,6 @@ class Article(mongoengine.Document):
5858

5959

6060
class EmbeddedArticle(mongoengine.EmbeddedDocument):
61-
6261
meta = {"collection": "test_embedded_article"}
6362
headline = mongoengine.StringField(required=True)
6463
pub_date = mongoengine.DateTimeField(default=datetime.now)
@@ -72,7 +71,6 @@ class EmbeddedFoo(mongoengine.EmbeddedDocument):
7271

7372

7473
class Reporter(mongoengine.Document):
75-
7674
meta = {"collection": "test_reporter"}
7775
id = mongoengine.StringField(primary_key=True)
7876
first_name = mongoengine.StringField(required=True)
@@ -84,7 +82,7 @@ class Reporter(mongoengine.Document):
8482
mongoengine.EmbeddedDocumentField(EmbeddedArticle)
8583
)
8684
embedded_list_articles = mongoengine.EmbeddedDocumentListField(EmbeddedArticle)
87-
generic_reference = mongoengine.GenericReferenceField(choices=[Article, Editor])
85+
generic_reference = mongoengine.GenericReferenceField(choices=[Article, Editor], required=True)
8886
generic_embedded_document = mongoengine.GenericEmbeddedDocumentField(
8987
choices=[EmbeddedArticle, EmbeddedFoo]
9088
)
@@ -94,7 +92,6 @@ class Reporter(mongoengine.Document):
9492

9593

9694
class Player(mongoengine.Document):
97-
9895
meta = {"collection": "test_player"}
9996
first_name = mongoengine.StringField(required=True)
10097
last_name = mongoengine.StringField(required=True)
@@ -105,29 +102,25 @@ class Player(mongoengine.Document):
105102

106103

107104
class Parent(mongoengine.Document):
108-
109105
meta = {"collection": "test_parent", "allow_inheritance": True}
110106
bar = mongoengine.StringField()
111107
loc = mongoengine.MultiPolygonField()
112108

113109

114110
class CellTower(mongoengine.Document):
115-
116111
meta = {"collection": "test_cell_tower"}
117112
code = mongoengine.StringField()
118113
base = mongoengine.PolygonField()
119114
coverage_area = mongoengine.MultiPolygonField()
120115

121116

122117
class Child(Parent):
123-
124118
meta = {"collection": "test_child"}
125119
baz = mongoengine.StringField()
126120
loc = mongoengine.PointField()
127121

128122

129123
class ProfessorMetadata(mongoengine.EmbeddedDocument):
130-
131124
meta = {"collection": "test_professor_metadata"}
132125
id = mongoengine.StringField(primary_key=False)
133126
first_name = mongoengine.StringField()
@@ -136,14 +129,12 @@ class ProfessorMetadata(mongoengine.EmbeddedDocument):
136129

137130

138131
class ProfessorVector(mongoengine.Document):
139-
140132
meta = {"collection": "test_professor_vector"}
141133
vec = mongoengine.ListField(mongoengine.FloatField())
142134
metadata = mongoengine.EmbeddedDocumentField(ProfessorMetadata)
143135

144136

145137
class ParentWithRelationship(mongoengine.Document):
146-
147138
meta = {"collection": "test_parent_reference"}
148139
before_child = mongoengine.ListField(
149140
mongoengine.ReferenceField("ChildRegisteredBefore")
@@ -155,14 +146,12 @@ class ParentWithRelationship(mongoengine.Document):
155146

156147

157148
class ChildRegisteredBefore(mongoengine.Document):
158-
159149
meta = {"collection": "test_child_before_reference"}
160150
parent = mongoengine.ReferenceField(ParentWithRelationship)
161151
name = mongoengine.StringField()
162152

163153

164154
class ChildRegisteredAfter(mongoengine.Document):
165-
166155
meta = {"collection": "test_child_after_reference"}
167156
parent = mongoengine.ReferenceField(ParentWithRelationship)
168157
name = mongoengine.StringField()

graphene_mongo/tests/test_converter.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -305,8 +305,8 @@ class Meta:
305305
Article._fields["pub_date"], A._meta.registry
306306
)
307307
assert (
308-
pubDate_field.kwargs["description"]
309-
== "Publication Date\nThe date of first press."
308+
pubDate_field.kwargs["description"]
309+
== "Publication Date\nThe date of first press."
310310
)
311311

312312
firstName_field = convert_mongoengine_field(
@@ -352,8 +352,12 @@ class Meta:
352352
Reporter._fields["generic_reference"], registry.get_global_registry()
353353
)
354354
assert isinstance(generic_reference_field, graphene.Field)
355-
assert isinstance(generic_reference_field.type(), graphene.Union)
356-
assert generic_reference_field.type()._meta.types == (A, E)
355+
if not Reporter._fields["generic_reference"].required:
356+
assert isinstance(generic_reference_field.type(), graphene.Union)
357+
assert generic_reference_field.type()._meta.types == (A, E)
358+
else:
359+
assert issubclass(generic_reference_field.type.of_type, graphene.Union)
360+
assert generic_reference_field.type.of_type._meta.types == (A, E)
357361

358362

359363
def test_should_generic_embedded_document_convert_union():

graphene_mongo/tests/test_relay_query.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -455,6 +455,7 @@ class Query(graphene.ObjectType):
455455
}
456456
}
457457
schema = graphene.Schema(query=Query)
458+
print(schema)
458459
result = schema.execute(query)
459460
assert not result.errors
460461
assert result.data == expected

0 commit comments

Comments
 (0)