5
5
6
6
import graphene
7
7
import mongoengine
8
+ from graphql .utils .ast_to_dict import ast_to_dict
8
9
from promise import Promise
9
10
from graphql_relay import from_global_id
10
11
from graphene .relay import ConnectionField
21
22
)
22
23
from .converter import convert_mongoengine_field , MongoEngineConversionError
23
24
from .registry import get_global_registry
24
- from .utils import get_model_reference_fields , get_node_from_global_id
25
+ from .utils import get_model_reference_fields , get_node_from_global_id , get_query_fields , camel_to_snake
25
26
26
27
27
28
class MongoengineConnectionField (ConnectionField ):
@@ -98,18 +99,18 @@ def is_filterable(k):
98
99
if isinstance (converted , (ConnectionField , Dynamic )):
99
100
return False
100
101
if callable (getattr (converted , "type" , None )) and isinstance (
101
- converted .type (),
102
- (
103
- FileFieldType ,
104
- PointFieldType ,
105
- MultiPolygonFieldType ,
106
- graphene .Union ,
107
- PolygonFieldType ,
108
- ),
102
+ converted .type (),
103
+ (
104
+ FileFieldType ,
105
+ PointFieldType ,
106
+ MultiPolygonFieldType ,
107
+ graphene .Union ,
108
+ PolygonFieldType ,
109
+ ),
109
110
):
110
111
return False
111
112
if isinstance (converted , (graphene .List )) and issubclass (
112
- getattr (converted , "_of_type" , None ), graphene .Union
113
+ getattr (converted , "_of_type" , None ), graphene .Union
113
114
):
114
115
return False
115
116
@@ -160,16 +161,16 @@ def get_reference_field(r, kv):
160
161
field = kv [1 ]
161
162
mongo_field = getattr (self .model , kv [0 ], None )
162
163
if isinstance (
163
- mongo_field ,
164
- (mongoengine .LazyReferenceField , mongoengine .ReferenceField ),
164
+ mongo_field ,
165
+ (mongoengine .LazyReferenceField , mongoengine .ReferenceField ),
165
166
):
166
167
field = convert_mongoengine_field (mongo_field , self .registry )
167
168
if callable (getattr (field , "get_type" , None )):
168
169
_type = field .get_type ()
169
170
if _type :
170
171
node = _type ._type ._meta
171
172
if "id" in node .fields and not issubclass (
172
- node .model , (mongoengine .EmbeddedDocument ,)
173
+ node .model , (mongoengine .EmbeddedDocument ,)
173
174
):
174
175
r .update ({kv [0 ]: node .fields ["id" ]._type .of_type ()})
175
176
return r
@@ -180,7 +181,7 @@ def get_reference_field(r, kv):
180
181
def fields (self ):
181
182
return self ._type ._meta .fields
182
183
183
- def get_queryset (self , model , info , ** args ):
184
+ def get_queryset (self , model , info , only_fields = list (), ** args ):
184
185
if args :
185
186
reference_fields = get_model_reference_fields (self .model )
186
187
hydrated_references = {}
@@ -198,12 +199,13 @@ def get_queryset(self, model, info, **args):
198
199
return queryset_or_filters
199
200
else :
200
201
args .update (queryset_or_filters )
202
+
201
203
return model .objects (** args ).order_by (self .order_by )
202
204
203
- def default_resolver (self , _root , info , ** args ):
205
+ def default_resolver (self , _root , info , only_fields = list (), ** args ):
204
206
args = args or {}
205
207
206
- if _root is not None :
208
+ if _root is not None and getattr ( _root , info . field_name , []) is not None :
207
209
args ["pk__in" ] = [r .pk for r in getattr (_root , info .field_name , [])]
208
210
209
211
connection_args = {
@@ -219,7 +221,7 @@ def default_resolver(self, _root, info, **args):
219
221
args ['pk' ] = from_global_id (_id )[- 1 ]
220
222
221
223
if callable (getattr (self .model , "objects" , None )):
222
- iterables = self .get_queryset (self .model , info , ** args )
224
+ iterables = self .get_queryset (self .model , info , only_fields , ** args )
223
225
list_length = iterables .count ()
224
226
else :
225
227
iterables = []
@@ -239,23 +241,31 @@ def default_resolver(self, _root, info, **args):
239
241
return connection
240
242
241
243
def chained_resolver (self , resolver , is_partial , root , info , ** args ):
244
+ only_fields = list ()
245
+ for field in get_query_fields (info ):
246
+ if camel_to_snake (field ) in self .model ._fields_ordered :
247
+ only_fields .append (camel_to_snake (field ))
242
248
if not bool (args ) or not is_partial :
249
+ if isinstance (self .model , mongoengine .Document ) or isinstance (self .model ,
250
+ mongoengine .base .metaclasses .TopLevelDocumentMetaclass ):
251
+ args_copy = args .copy ()
252
+ for arg_name , arg in args .copy ().items ():
253
+ if arg_name not in self .model ._fields_ordered :
254
+ args_copy .pop (arg_name )
243
255
# XXX: Filter nested args
244
256
resolved = resolver (root , info , ** args )
245
257
if resolved is not None :
246
258
return resolved
247
- return self .default_resolver (root , info , ** args )
259
+ return self .default_resolver (root , info , only_fields , ** args )
248
260
249
261
@classmethod
250
262
def connection_resolver (cls , resolver , connection_type , root , info , ** args ):
251
263
iterable = resolver (root , info , ** args )
252
264
if isinstance (connection_type , graphene .NonNull ):
253
265
connection_type = connection_type .of_type
254
-
255
266
on_resolve = partial (cls .resolve_connection , connection_type , args )
256
267
if Promise .is_thenable (iterable ):
257
268
return Promise .resolve (iterable ).then (on_resolve )
258
-
259
269
return on_resolve (iterable )
260
270
261
271
def get_resolver (self , parent_resolver ):
0 commit comments