|
4 | 4 | from functools import partial, reduce
|
5 | 5 |
|
6 | 6 | import mongoengine
|
| 7 | +from graphene import PageInfo |
7 | 8 | from graphene.relay import ConnectionField
|
8 |
| -from graphene.relay.connection import PageInfo |
9 | 9 | from graphene.types.argument import to_arguments
|
10 | 10 | from graphene.types.dynamic import Dynamic
|
11 | 11 | from graphene.types.structures import Structure, List
|
12 |
| -from graphql_relay import from_global_id |
13 | 12 | from graphql_relay.connection.arrayconnection import connection_from_list_slice
|
14 | 13 |
|
15 | 14 | from .advanced_types import PointFieldType, MultiPolygonFieldType
|
16 | 15 | from .converter import convert_mongoengine_field, MongoEngineConversionError
|
17 | 16 | from .registry import get_global_registry
|
18 |
| -from .utils import get_model_reference_fields, node_from_global_id |
| 17 | +from .utils import get_model_reference_fields, global_id_via_node |
19 | 18 |
|
20 | 19 |
|
21 | 20 | class MongoengineConnectionField(ConnectionField):
|
22 | 21 |
|
23 | 22 | def __init__(self, type, *args, **kwargs):
|
| 23 | + get_queryset = kwargs.pop('get_queryset', None) |
| 24 | + if get_queryset: |
| 25 | + assert callable(get_queryset), "Attribute `get_queryset` on {} must be callable.".format(self) |
| 26 | + self._get_queryset = get_queryset |
24 | 27 | super(MongoengineConnectionField, self).__init__(
|
25 | 28 | type,
|
26 | 29 | *args,
|
@@ -109,91 +112,65 @@ def get_reference_field(r, kv):
|
109 | 112 | def fields(self):
|
110 | 113 | return self._type._meta.fields
|
111 | 114 |
|
112 |
| - @classmethod |
113 |
| - def get_query(cls, model, connection, info, **args): |
114 |
| - |
115 |
| - if not callable(getattr(model, 'objects', None)): |
| 115 | + def get_queryset(self, model, info, **args): |
| 116 | + if self._get_queryset: |
| 117 | + queryset_or_filters = self._get_queryset(model, info, **args) |
| 118 | + if isinstance(queryset_or_filters, mongoengine.QuerySet): |
| 119 | + return queryset_or_filters |
| 120 | + else: |
| 121 | + return model.objects(**queryset_or_filters) |
| 122 | + return model.objects() |
| 123 | + |
| 124 | + def default_resolver(self, _root, info, **args): |
| 125 | + if not callable(getattr(self.model, 'objects', None)): |
116 | 126 | return [], 0
|
117 | 127 |
|
118 |
| - objs = model.objects() |
| 128 | + args = args or {} |
| 129 | + |
| 130 | + connection_args = { |
| 131 | + 'first': args.pop('first', None), |
| 132 | + 'last': args.pop('last', None), |
| 133 | + 'before': args.pop('before', None), |
| 134 | + 'after': args.pop('after', None) |
| 135 | + } |
| 136 | + |
| 137 | + objs = self.get_queryset(self.model, info, **args) |
| 138 | + |
119 | 139 | if args:
|
120 |
| - reference_fields = get_model_reference_fields(model) |
| 140 | + reference_fields = get_model_reference_fields(self.model) |
121 | 141 | reference_args = {}
|
122 | 142 | for arg_name, arg in args.copy().items():
|
123 | 143 | if arg_name in reference_fields:
|
124 |
| - reference_model = model._fields[arg_name] |
125 |
| - pk = node_from_global_id(connection, args.pop(arg_name))[-1] |
| 144 | + reference_model = self.model._fields[arg_name] |
| 145 | + pk = global_id_via_node(self.node_type, args.pop(arg_name))[-1] |
126 | 146 | reference_obj = reference_model.document_type_obj.objects(pk=pk).get()
|
127 | 147 | reference_args[arg_name] = reference_obj
|
128 | 148 |
|
129 | 149 | args.update(reference_args)
|
130 |
| - first = args.pop('first', None) |
131 |
| - last = args.pop('last', None) |
132 | 150 | _id = args.pop('id', None)
|
133 |
| - before = args.pop('before', None) |
134 |
| - after = args.pop('after', None) |
135 |
| - |
136 | 151 | if _id is not None:
|
137 |
| - # https://github.com/graphql-python/graphene/issues/124 |
138 |
| - args['pk'] = node_from_global_id(connection, _id)[-1] |
| 152 | + args['pk'] = global_id_via_node(self.node_type, _id)[-1] |
139 | 153 |
|
140 | 154 | objs = objs.filter(**args)
|
141 | 155 |
|
142 |
| - # https://github.com/graphql-python/graphene-mongo/issues/21 |
143 |
| - if after is not None: |
144 |
| - _after = int(from_global_id(after)[-1]) |
145 |
| - objs = objs[_after:] |
146 |
| - |
147 |
| - if before is not None: |
148 |
| - _before = int(from_global_id(before)[-1]) |
149 |
| - objs = objs[:_before] |
150 |
| - |
151 |
| - list_length = objs.count() |
152 |
| - |
153 |
| - if first is not None: |
154 |
| - objs = objs[:first] |
155 |
| - if last is not None: |
156 |
| - # https://github.com/graphql-python/graphene-mongo/issues/20 |
157 |
| - objs = objs[max(0, list_length - last):] |
158 |
| - else: |
159 |
| - list_length = objs.count() |
160 |
| - |
161 |
| - return objs, list_length |
162 |
| - |
163 |
| - # noqa |
164 |
| - @classmethod |
165 |
| - def merge_querysets(cls, default_queryset, queryset): |
166 |
| - return queryset & default_queryset |
167 |
| - |
168 |
| - """ |
169 |
| - Notes: Not sure how does this work :( |
170 |
| - """ |
171 |
| - @classmethod |
172 |
| - def connection_resolver(cls, resolver, connection, model, root, info, **args): |
173 |
| - iterable = resolver(root, info, **args) |
174 |
| - |
175 |
| - if iterable or iterable == []: |
176 |
| - _len = len(iterable) |
177 |
| - else: |
178 |
| - iterable, _len = cls.get_query(model, connection, info, **args) |
179 |
| - |
180 |
| - if root: |
181 |
| - # If we have a root, we must be at least 1 layer in, right? |
182 |
| - _len = 0 |
183 |
| - |
184 | 156 | connection = connection_from_list_slice(
|
185 |
| - iterable, |
186 |
| - args, |
187 |
| - slice_start=0, |
188 |
| - list_length=_len, |
189 |
| - list_slice_length=_len, |
190 |
| - connection_type=connection, |
| 157 | + list_slice=objs, |
| 158 | + args=connection_args, |
| 159 | + list_length=objs.count(), |
| 160 | + connection_type=self.type, |
| 161 | + edge_type=self.type.Edge, |
191 | 162 | pageinfo_type=PageInfo,
|
192 |
| - edge_type=connection.Edge, |
193 | 163 | )
|
194 |
| - connection.iterable = iterable |
195 |
| - connection.length = _len |
| 164 | + connection.iterable = objs |
196 | 165 | return connection
|
197 | 166 |
|
| 167 | + def chained_resolver(self, resolver, root, info, **args): |
| 168 | + resolved = resolver(root, info, **args) |
| 169 | + if resolved is not None: |
| 170 | + return resolved |
| 171 | + return self.default_resolver(root, info, **args) |
| 172 | + |
198 | 173 | def get_resolver(self, parent_resolver):
|
199 |
| - return partial(self.connection_resolver, parent_resolver, self.type, self.model) |
| 174 | + super_resolver = self.resolver or parent_resolver |
| 175 | + resolver = partial(self.chained_resolver, super_resolver) |
| 176 | + return partial(self.connection_resolver, resolver, self.type) |
0 commit comments