Skip to content

Commit bc5f0c6

Browse files
authored
Merge pull request #164 from arunsureshkumar/support-generic-reference-field-in-args
Support generic reference field in args
2 parents 8eb3d1f + ef8b45e commit bc5f0c6

File tree

2 files changed

+19
-3
lines changed

2 files changed

+19
-3
lines changed

graphene_mongo/fields.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from graphene.types.utils import get_type
1111
from graphene.utils.str_converters import to_snake_case
1212
from graphql import ResolveInfo
13+
from mongoengine.base import get_document
1314
from promise import Promise
1415
from graphql_relay import from_global_id
1516
from graphene.relay import ConnectionField
@@ -177,6 +178,9 @@ def get_reference_field(r, kv):
177178
(mongoengine.LazyReferenceField, mongoengine.ReferenceField),
178179
):
179180
field = convert_mongoengine_field(mongo_field, self.registry)
181+
if isinstance(mongo_field, mongoengine.GenericReferenceField):
182+
r.update({kv[0]: graphene.ID()})
183+
return r
180184
if callable(getattr(field, "get_type", None)):
181185
_type = field.get_type()
182186
if _type:
@@ -207,6 +211,15 @@ def get_queryset(self, model, info, required_fields=list(), skip=None, limit=Non
207211
except TypeError:
208212
reference_obj = reference_fields[arg_name].document_type(pk=arg)
209213
hydrated_references[arg_name] = reference_obj
214+
elif arg_name in self.model._fields_ordered and isinstance(getattr(self.model, arg_name),
215+
mongoengine.fields.GenericReferenceField):
216+
try:
217+
reference_obj = get_document(self.registry._registry_string_map[from_global_id(arg)[0]])(
218+
pk=from_global_id(arg)[1])
219+
except TypeError:
220+
reference_obj = get_document(arg["_cls"])(
221+
pk=arg["_ref"].id)
222+
hydrated_references[arg_name] = reference_obj
210223
elif arg_name == "id":
211224
hydrated_references["id"] = from_global_id(args.pop("id", None))[1]
212225
args.update(hydrated_references)
@@ -286,7 +299,7 @@ def default_resolver(self, _root, info, required_fields=list(), **args):
286299
if not info.context:
287300
info.context = Context()
288301
info.context.queryset = self.get_queryset(self.model, info, required_fields, **args)
289-
elif _root is None:
302+
elif _root is None or args:
290303
count = self.get_queryset(self.model, info, required_fields, **args).count()
291304
if count != 0:
292305
skip, limit, reverse = find_skip_and_limit(first=first, after=after, last=last, before=before,
@@ -364,6 +377,9 @@ def chained_resolver(self, resolver, is_partial, root, info, **args):
364377
if arg_name not in self.model._fields_ordered + ('first', 'last', 'before', 'after') + tuple(
365378
self.filter_args.keys()):
366379
args_copy.pop(arg_name)
380+
if '.' in arg_name:
381+
operation = list(arg.keys())[0]
382+
args_copy[arg_name.replace('.', '__') + operation.replace('$', '__')] = arg[operation]
367383
return self.default_resolver(root, info, required_fields, **args_copy)
368384
else:
369385
return resolved

graphene_mongo/tests/test_converter.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -305,8 +305,8 @@ class Meta:
305305
Article._fields["pub_date"], A._meta.registry
306306
)
307307
assert (
308-
pubDate_field.kwargs["description"]
309-
== "Publication Date\nThe date of first press."
308+
pubDate_field.kwargs["description"]
309+
== "Publication Date\nThe date of first press."
310310
)
311311

312312
firstName_field = convert_mongoengine_field(

0 commit comments

Comments
 (0)