Skip to content

Commit ef5c214

Browse files
committed
feat: Enables user customization by allowing subclassing and runtime filters on Connection fields.
• Subclassing or replacing automatically attached/converted MongoengineConnectionField is possible by specifying `connection_field_class` in the ObjectType meta. • A new API for specifying a callable that can return custom filters or a queryset when generating a connection iterable: `get_queryset`. • Refactors MongoengineConnectionField to use existing graphene connection functions instead of reinventing list slicing. • Refactors MongoengineConnectionField to do object lookup as an instance method instead of doing wacky things with the connection class methods. Should make subclassing get_queryset and default_resolver much more useful and intuitive.
1 parent 65fd549 commit ef5c214

File tree

7 files changed

+74
-85
lines changed

7 files changed

+74
-85
lines changed

graphene_mongo/converter.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -86,8 +86,6 @@ def convert_field_to_datetime(field, registry=None):
8686
@convert_mongoengine_field.register(mongoengine.ListField)
8787
@convert_mongoengine_field.register(mongoengine.EmbeddedDocumentListField)
8888
def convert_field_to_list(field, registry=None):
89-
from .fields import MongoengineConnectionField
90-
9189
base_type = convert_mongoengine_field(field.field, registry=registry)
9290
if isinstance(base_type, (Dynamic)):
9391
base_type = base_type.get_type()
@@ -96,7 +94,7 @@ def convert_field_to_list(field, registry=None):
9694
base_type = base_type._type
9795

9896
if is_node(base_type):
99-
return MongoengineConnectionField(base_type)
97+
return base_type._meta.connection_field_class(base_type)
10098

10199
# Non-relationship field
102100
relations = (mongoengine.ReferenceField, mongoengine.EmbeddedDocumentField)

graphene_mongo/fields.py

Lines changed: 47 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -4,23 +4,26 @@
44
from functools import partial, reduce
55

66
import mongoengine
7+
from graphene import PageInfo
78
from graphene.relay import ConnectionField
8-
from graphene.relay.connection import PageInfo
99
from graphene.types.argument import to_arguments
1010
from graphene.types.dynamic import Dynamic
1111
from graphene.types.structures import Structure, List
12-
from graphql_relay import from_global_id
1312
from graphql_relay.connection.arrayconnection import connection_from_list_slice
1413

1514
from .advanced_types import PointFieldType, MultiPolygonFieldType
1615
from .converter import convert_mongoengine_field, MongoEngineConversionError
1716
from .registry import get_global_registry
18-
from .utils import get_model_reference_fields, node_from_global_id
17+
from .utils import get_model_reference_fields, global_id_via_node
1918

2019

2120
class MongoengineConnectionField(ConnectionField):
2221

2322
def __init__(self, type, *args, **kwargs):
23+
get_queryset = kwargs.pop('get_queryset', None)
24+
if get_queryset:
25+
assert callable(get_queryset), "Attribute `get_queryset` on {} must be callable.".format(self)
26+
self._get_queryset = get_queryset
2427
super(MongoengineConnectionField, self).__init__(
2528
type,
2629
*args,
@@ -109,91 +112,65 @@ def get_reference_field(r, kv):
109112
def fields(self):
110113
return self._type._meta.fields
111114

112-
@classmethod
113-
def get_query(cls, model, connection, info, **args):
114-
115-
if not callable(getattr(model, 'objects', None)):
115+
def get_queryset(self, model, info, **args):
116+
if self._get_queryset:
117+
queryset_or_filters = self._get_queryset(model, info, **args)
118+
if isinstance(queryset_or_filters, mongoengine.QuerySet):
119+
return queryset_or_filters
120+
else:
121+
return model.objects(**queryset_or_filters)
122+
return model.objects()
123+
124+
def default_resolver(self, _root, info, **args):
125+
if not callable(getattr(self.model, 'objects', None)):
116126
return [], 0
117127

118-
objs = model.objects()
128+
args = args or {}
129+
130+
connection_args = {
131+
'first': args.pop('first', None),
132+
'last': args.pop('last', None),
133+
'before': args.pop('before', None),
134+
'after': args.pop('after', None)
135+
}
136+
137+
objs = self.get_queryset(self.model, info, **args)
138+
119139
if args:
120-
reference_fields = get_model_reference_fields(model)
140+
reference_fields = get_model_reference_fields(self.model)
121141
reference_args = {}
122142
for arg_name, arg in args.copy().items():
123143
if arg_name in reference_fields:
124-
reference_model = model._fields[arg_name]
125-
pk = node_from_global_id(connection, args.pop(arg_name))[-1]
144+
reference_model = self.model._fields[arg_name]
145+
pk = global_id_via_node(self.node_type, args.pop(arg_name))[-1]
126146
reference_obj = reference_model.document_type_obj.objects(pk=pk).get()
127147
reference_args[arg_name] = reference_obj
128148

129149
args.update(reference_args)
130-
first = args.pop('first', None)
131-
last = args.pop('last', None)
132150
_id = args.pop('id', None)
133-
before = args.pop('before', None)
134-
after = args.pop('after', None)
135-
136151
if _id is not None:
137-
# https://github.com/graphql-python/graphene/issues/124
138-
args['pk'] = node_from_global_id(connection, _id)[-1]
152+
args['pk'] = global_id_via_node(self.node_type, _id)[-1]
139153

140154
objs = objs.filter(**args)
141155

142-
# https://github.com/graphql-python/graphene-mongo/issues/21
143-
if after is not None:
144-
_after = int(from_global_id(after)[-1])
145-
objs = objs[_after:]
146-
147-
if before is not None:
148-
_before = int(from_global_id(before)[-1])
149-
objs = objs[:_before]
150-
151-
list_length = objs.count()
152-
153-
if first is not None:
154-
objs = objs[:first]
155-
if last is not None:
156-
# https://github.com/graphql-python/graphene-mongo/issues/20
157-
objs = objs[max(0, list_length - last):]
158-
else:
159-
list_length = objs.count()
160-
161-
return objs, list_length
162-
163-
# noqa
164-
@classmethod
165-
def merge_querysets(cls, default_queryset, queryset):
166-
return queryset & default_queryset
167-
168-
"""
169-
Notes: Not sure how does this work :(
170-
"""
171-
@classmethod
172-
def connection_resolver(cls, resolver, connection, model, root, info, **args):
173-
iterable = resolver(root, info, **args)
174-
175-
if iterable or iterable == []:
176-
_len = len(iterable)
177-
else:
178-
iterable, _len = cls.get_query(model, connection, info, **args)
179-
180-
if root:
181-
# If we have a root, we must be at least 1 layer in, right?
182-
_len = 0
183-
184156
connection = connection_from_list_slice(
185-
iterable,
186-
args,
187-
slice_start=0,
188-
list_length=_len,
189-
list_slice_length=_len,
190-
connection_type=connection,
157+
list_slice=objs,
158+
args=connection_args,
159+
list_length=objs.count(),
160+
connection_type=self.type,
161+
edge_type=self.type.Edge,
191162
pageinfo_type=PageInfo,
192-
edge_type=connection.Edge,
193163
)
194-
connection.iterable = iterable
195-
connection.length = _len
164+
connection.iterable = objs
196165
return connection
197166

167+
def chained_resolver(self, resolver, root, info, **args):
168+
resolved = resolver(root, info, **args)
169+
if resolved is not None:
170+
return resolved
171+
return self.default_resolver(root, info, **args)
172+
198173
def get_resolver(self, parent_resolver):
199-
return partial(self.connection_resolver, parent_resolver, self.type, self.model)
174+
super_resolver = self.resolver or parent_resolver
175+
resolver = partial(self.chained_resolver, super_resolver)
176+
return partial(self.connection_resolver, resolver, self.type)

graphene_mongo/tests/setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,3 +143,4 @@ def fixtures():
143143
child3.parent = child4.parent = parent
144144
child3.save()
145145
child4.save()
146+
return True

graphene_mongo/tests/test_relay_query.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -514,7 +514,7 @@ class Query(graphene.ObjectType):
514514
players = MongoengineConnectionField(PlayerNode)
515515

516516
query = '''
517-
query EditorQuery {
517+
query PlayerQuery {
518518
players(last: 2) {
519519
edges {
520520
cursor,

graphene_mongo/types.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
from collections import OrderedDict
22

3-
from graphene import Field
3+
from graphene import Field, ConnectionField
44
from graphene.relay import Connection, Node
55
from graphene.types.objecttype import ObjectType, ObjectTypeOptions
66
from graphene.types.utils import yank_fields_from_attrs
77
from mongoengine import ListField
88

9+
from graphene_mongo import MongoengineConnectionField
910
from .converter import convert_mongoengine_field
1011
from .registry import Registry, get_global_registry
1112
from .utils import (get_model_fields, is_valid_mongoengine_model)
@@ -61,18 +62,20 @@ class MongoengineObjectType(ObjectType):
6162

6263
@classmethod
6364
def __init_subclass_with_meta__(cls, model=None, registry=None, skip_registry=False,
64-
only_fields=(), exclude_fields=(), filter_fields=None, connection=None,
65-
connection_class=None, use_connection=None, interfaces=(), **options):
65+
only_fields=(), exclude_fields=(), filter_fields=None,
66+
connection=None, connection_class=None, use_connection=None,
67+
connection_field_class=None, interfaces=(), **options):
6668

6769
assert is_valid_mongoengine_model(model), (
68-
'You need to pass a valid Mongoengine Model in {}.Meta, received "{}".'
69-
).format(cls.__name__, model)
70+
'The attribute model in {}.Meta must be a valid Mongoengine Model. '
71+
'Received "{}" instead.'
72+
).format(cls.__name__, type(model))
7073

7174
if not registry:
7275
registry = get_global_registry()
7376

7477
assert isinstance(registry, Registry), (
75-
'The attribute registry in {} needs to be an instance of '
78+
'The attribute registry in {}.Meta needs to be an instance of '
7679
'Registry, received "{}".'
7780
).format(cls.__name__, registry)
7881

@@ -93,15 +96,25 @@ def __init_subclass_with_meta__(cls, model=None, registry=None, skip_registry=Fa
9396

9497
if connection is not None:
9598
assert issubclass(connection, Connection), (
96-
'The connection must be a Connection. Received {}'
97-
).format(connection.__name__)
99+
'The attribute connection in {}.Meta must be of type Connection. '
100+
'Received "{}" instead.'
101+
).format(cls.__name__, type(connection))
102+
103+
if connection_field_class is not None:
104+
assert issubclass(connection_field_class, ConnectionField), (
105+
'The attribute connection_field_class in {}.Meta must be of type ConnectionField. '
106+
'Received "{}" instead.'
107+
).format(cls.__name__, type(connection_field_class))
108+
else:
109+
connection_field_class = MongoengineConnectionField
98110

99111
_meta = MongoengineObjectTypeOptions(cls)
100112
_meta.model = model
101113
_meta.registry = registry
102114
_meta.fields = mongoengine_fields
103115
_meta.filter_fields = filter_fields
104116
_meta.connection = connection
117+
_meta.connection_field_class = connection_field_class
105118
# Save them for later
106119
_meta.only_fields = only_fields
107120
_meta.exclude_fields = exclude_fields

graphene_mongo/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,9 +91,9 @@ def get_field_description(field, registry=None):
9191
return "\n".join(parts)
9292

9393

94-
def node_from_global_id(connection, _id):
94+
def global_id_via_node(node, _id):
9595
try:
96-
for interface in connection._meta.node._meta.interfaces:
96+
for interface in node._meta.interfaces:
9797
if issubclass(interface, Node):
9898
return interface.from_global_id(_id)
9999
except AttributeError:

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
setup(
44
name='graphene-mongo',
5-
version='0.1.19',
5+
version='0.2.0',
66

77
description='Graphene Mongoengine integration',
88
long_description=open('README.rst').read(),

0 commit comments

Comments
 (0)