|
10 | 10 | from graphene.types.utils import get_type
|
11 | 11 | from graphene.utils.str_converters import to_snake_case
|
12 | 12 | from graphql import ResolveInfo
|
| 13 | +from mongoengine.base import get_document |
13 | 14 | from promise import Promise
|
14 | 15 | from graphql_relay import from_global_id
|
15 | 16 | from graphene.relay import ConnectionField
|
@@ -177,6 +178,9 @@ def get_reference_field(r, kv):
|
177 | 178 | (mongoengine.LazyReferenceField, mongoengine.ReferenceField),
|
178 | 179 | ):
|
179 | 180 | field = convert_mongoengine_field(mongo_field, self.registry)
|
| 181 | + if isinstance(mongo_field, mongoengine.GenericReferenceField): |
| 182 | + r.update({kv[0]: graphene.ID()}) |
| 183 | + return r |
180 | 184 | if callable(getattr(field, "get_type", None)):
|
181 | 185 | _type = field.get_type()
|
182 | 186 | if _type:
|
@@ -207,6 +211,15 @@ def get_queryset(self, model, info, required_fields=list(), skip=None, limit=Non
|
207 | 211 | except TypeError:
|
208 | 212 | reference_obj = reference_fields[arg_name].document_type(pk=arg)
|
209 | 213 | hydrated_references[arg_name] = reference_obj
|
| 214 | + elif arg_name in self.model._fields_ordered and isinstance(getattr(self.model, arg_name), |
| 215 | + mongoengine.fields.GenericReferenceField): |
| 216 | + try: |
| 217 | + reference_obj = get_document(self.registry._registry_string_map[from_global_id(arg)[0]])( |
| 218 | + pk=from_global_id(arg)[1]) |
| 219 | + except TypeError: |
| 220 | + reference_obj = get_document(arg["_cls"])( |
| 221 | + pk=arg["_ref"].id) |
| 222 | + hydrated_references[arg_name] = reference_obj |
210 | 223 | elif arg_name == "id":
|
211 | 224 | hydrated_references["id"] = from_global_id(args.pop("id", None))[1]
|
212 | 225 | args.update(hydrated_references)
|
@@ -286,7 +299,7 @@ def default_resolver(self, _root, info, required_fields=list(), **args):
|
286 | 299 | if not info.context:
|
287 | 300 | info.context = Context()
|
288 | 301 | info.context.queryset = self.get_queryset(self.model, info, required_fields, **args)
|
289 |
| - elif _root is None: |
| 302 | + elif _root is None or args: |
290 | 303 | count = self.get_queryset(self.model, info, required_fields, **args).count()
|
291 | 304 | if count != 0:
|
292 | 305 | skip, limit, reverse = find_skip_and_limit(first=first, after=after, last=last, before=before,
|
@@ -364,6 +377,9 @@ def chained_resolver(self, resolver, is_partial, root, info, **args):
|
364 | 377 | if arg_name not in self.model._fields_ordered + ('first', 'last', 'before', 'after') + tuple(
|
365 | 378 | self.filter_args.keys()):
|
366 | 379 | args_copy.pop(arg_name)
|
| 380 | + if '.' in arg_name: |
| 381 | + operation = list(arg.keys())[0] |
| 382 | + args_copy[arg_name.replace('.', '__') + operation.replace('$', '__')] = arg[operation] |
367 | 383 | return self.default_resolver(root, info, required_fields, **args_copy)
|
368 | 384 | else:
|
369 | 385 | return resolved
|
|
0 commit comments