Skip to content

Commit 65fd549

Browse files
committed
feat: Add support for Node interface subclasses that override global_id functions.
1 parent 8616631 commit 65fd549

File tree

3 files changed

+62
-26
lines changed

3 files changed

+62
-26
lines changed

graphene_mongo/converter.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,18 @@
1616
import mongoengine
1717

1818
from .advanced_types import PointFieldType, MultiPolygonFieldType
19-
from .fields import MongoengineConnectionField
2019
from .utils import import_single_dispatch, get_field_description
2120

2221
singledispatch = import_single_dispatch()
2322

2423

24+
class MongoEngineConversionError(Exception):
25+
pass
26+
27+
2528
@singledispatch
2629
def convert_mongoengine_field(field, registry=None):
27-
raise Exception(
30+
raise MongoEngineConversionError(
2831
"Don't know how to convert the MongoEngine field %s (%s)" %
2932
(field, field.__class__))
3033

@@ -83,6 +86,8 @@ def convert_field_to_datetime(field, registry=None):
8386
@convert_mongoengine_field.register(mongoengine.ListField)
8487
@convert_mongoengine_field.register(mongoengine.EmbeddedDocumentListField)
8588
def convert_field_to_list(field, registry=None):
89+
from .fields import MongoengineConnectionField
90+
8691
base_type = convert_mongoengine_field(field.field, registry=registry)
8792
if isinstance(base_type, (Dynamic)):
8893
base_type = base_type.get_type()

graphene_mongo/fields.py

Lines changed: 44 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,21 @@
11
from __future__ import absolute_import
22

3-
import mongoengine
43
from collections import OrderedDict
54
from functools import partial, reduce
65

6+
import mongoengine
77
from graphene.relay import ConnectionField
88
from graphene.relay.connection import PageInfo
9-
from graphql_relay.connection.arrayconnection import connection_from_list_slice
10-
from graphql_relay.node.node import from_global_id
119
from graphene.types.argument import to_arguments
1210
from graphene.types.dynamic import Dynamic
13-
from graphene.types.structures import Structure
11+
from graphene.types.structures import Structure, List
12+
from graphql_relay import from_global_id
13+
from graphql_relay.connection.arrayconnection import connection_from_list_slice
1414

1515
from .advanced_types import PointFieldType, MultiPolygonFieldType
16-
from .utils import get_model_reference_fields
16+
from .converter import convert_mongoengine_field, MongoEngineConversionError
17+
from .registry import get_global_registry
18+
from .utils import get_model_reference_fields, node_from_global_id
1719

1820

1921
class MongoengineConnectionField(ConnectionField):
@@ -43,6 +45,10 @@ def node_type(self):
4345
def model(self):
4446
return self.node_type._meta.model
4547

48+
@property
49+
def registry(self):
50+
return getattr(self.node_type._meta, 'registry', get_global_registry())
51+
4652
@property
4753
def args(self):
4854
return to_arguments(
@@ -55,12 +61,19 @@ def args(self, args):
5561
self._base_args = args
5662

5763
def _field_args(self, items):
58-
def is_filterable(v):
59-
if isinstance(v, (ConnectionField, Dynamic)):
64+
def is_filterable(k):
65+
if not hasattr(self.model, k):
66+
return False
67+
if isinstance(getattr(self.model, k), property):
6068
return False
61-
# FIXME: Skip PointTypeField at this moment.
62-
if not isinstance(v.type, Structure) \
63-
and isinstance(v.type(), (PointFieldType, MultiPolygonFieldType)):
69+
try:
70+
converted = convert_mongoengine_field(getattr(self.model, k), self.registry)
71+
except MongoEngineConversionError:
72+
return False
73+
if isinstance(converted, (ConnectionField, Dynamic, List)):
74+
return False
75+
if callable(getattr(converted, 'type', None)) and isinstance(converted.type(),
76+
(PointFieldType, MultiPolygonFieldType)):
6477
return False
6578
return True
6679

@@ -69,7 +82,7 @@ def get_type(v):
6982
return v.type.of_type()
7083
return v.type()
7184

72-
return {k: get_type(v) for k, v in items if is_filterable(v)}
85+
return {k: get_type(v) for k, v in items if is_filterable(k)}
7386

7487
@property
7588
def field_args(self):
@@ -78,19 +91,26 @@ def field_args(self):
7891
@property
7992
def reference_args(self):
8093
def get_reference_field(r, kv):
81-
if callable(getattr(kv[1], 'get_type', None)):
82-
node = kv[1].get_type()._type._meta
83-
if not issubclass(node.model, mongoengine.EmbeddedDocument):
84-
r.update({kv[0]: node.fields['id']._type.of_type()})
94+
field = kv[1]
95+
mongo_field = getattr(self.model, kv[0], None)
96+
if isinstance(mongo_field, (mongoengine.LazyReferenceField, mongoengine.ReferenceField)):
97+
field = convert_mongoengine_field(mongo_field, self.registry)
98+
if callable(getattr(field, 'get_type', None)):
99+
_type = field.get_type()
100+
if _type:
101+
node = _type._type._meta
102+
if 'id' in node.fields and not issubclass(node.model, mongoengine.EmbeddedDocument):
103+
r.update({kv[0]: node.fields['id']._type.of_type()})
85104
return r
105+
86106
return reduce(get_reference_field, self.fields.items(), {})
87107

88108
@property
89109
def fields(self):
90110
return self._type._meta.fields
91111

92112
@classmethod
93-
def get_query(cls, model, info, **args):
113+
def get_query(cls, model, connection, info, **args):
94114

95115
if not callable(getattr(model, 'objects', None)):
96116
return [], 0
@@ -102,20 +122,20 @@ def get_query(cls, model, info, **args):
102122
for arg_name, arg in args.copy().items():
103123
if arg_name in reference_fields:
104124
reference_model = model._fields[arg_name]
105-
pk = from_global_id(args.pop(arg_name))[-1]
125+
pk = node_from_global_id(connection, args.pop(arg_name))[-1]
106126
reference_obj = reference_model.document_type_obj.objects(pk=pk).get()
107127
reference_args[arg_name] = reference_obj
108128

109129
args.update(reference_args)
110130
first = args.pop('first', None)
111131
last = args.pop('last', None)
112-
id = args.pop('id', None)
132+
_id = args.pop('id', None)
113133
before = args.pop('before', None)
114134
after = args.pop('after', None)
115135

116-
if id is not None:
136+
if _id is not None:
117137
# https://github.com/graphql-python/graphene/issues/124
118-
args['pk'] = from_global_id(id)[-1]
138+
args['pk'] = node_from_global_id(connection, _id)[-1]
119139

120140
objs = objs.filter(**args)
121141

@@ -152,14 +172,14 @@ def merge_querysets(cls, default_queryset, queryset):
152172
def connection_resolver(cls, resolver, connection, model, root, info, **args):
153173
iterable = resolver(root, info, **args)
154174

155-
if not iterable:
156-
iterable, _len = cls.get_query(model, info, **args)
175+
if iterable or iterable == []:
176+
_len = len(iterable)
177+
else:
178+
iterable, _len = cls.get_query(model, connection, info, **args)
157179

158180
if root:
159181
# If we have a root, we must be at least 1 layer in, right?
160182
_len = 0
161-
else:
162-
_len = len(iterable)
163183

164184
connection = connection_from_list_slice(
165185
iterable,

graphene_mongo/utils.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44
from collections import OrderedDict
55

66
import mongoengine
7+
from graphene import Node
78
from graphene.utils.trim_docstring import trim_docstring
9+
from graphql_relay import from_global_id
810

911

1012
def get_model_fields(model, excluding=None):
@@ -87,3 +89,12 @@ def get_field_description(field, registry=None):
8789
parts.append(name_format % field.db_field)
8890

8991
return "\n".join(parts)
92+
93+
94+
def node_from_global_id(connection, _id):
95+
try:
96+
for interface in connection._meta.node._meta.interfaces:
97+
if issubclass(interface, Node):
98+
return interface.from_global_id(_id)
99+
except AttributeError:
100+
return from_global_id(_id)

0 commit comments

Comments
 (0)