1
1
from __future__ import absolute_import
2
2
3
- import mongoengine
4
3
from collections import OrderedDict
5
4
from functools import partial , reduce
6
5
6
+ import mongoengine
7
7
from graphene .relay import ConnectionField
8
8
from graphene .relay .connection import PageInfo
9
- from graphql_relay .connection .arrayconnection import connection_from_list_slice
10
- from graphql_relay .node .node import from_global_id
11
9
from graphene .types .argument import to_arguments
12
10
from graphene .types .dynamic import Dynamic
13
- from graphene .types .structures import Structure
11
+ from graphene .types .structures import Structure , List
12
+ from graphql_relay import from_global_id
13
+ from graphql_relay .connection .arrayconnection import connection_from_list_slice
14
14
15
15
from .advanced_types import PointFieldType , MultiPolygonFieldType
16
- from .utils import get_model_reference_fields
16
+ from .converter import convert_mongoengine_field , MongoEngineConversionError
17
+ from .registry import get_global_registry
18
+ from .utils import get_model_reference_fields , node_from_global_id
17
19
18
20
19
21
class MongoengineConnectionField (ConnectionField ):
@@ -43,6 +45,10 @@ def node_type(self):
43
45
def model (self ):
44
46
return self .node_type ._meta .model
45
47
48
+ @property
49
+ def registry (self ):
50
+ return getattr (self .node_type ._meta , 'registry' , get_global_registry ())
51
+
46
52
@property
47
53
def args (self ):
48
54
return to_arguments (
@@ -55,12 +61,19 @@ def args(self, args):
55
61
self ._base_args = args
56
62
57
63
def _field_args (self , items ):
58
- def is_filterable (v ):
59
- if isinstance (v , (ConnectionField , Dynamic )):
64
+ def is_filterable (k ):
65
+ if not hasattr (self .model , k ):
66
+ return False
67
+ if isinstance (getattr (self .model , k ), property ):
60
68
return False
61
- # FIXME: Skip PointTypeField at this moment.
62
- if not isinstance (v .type , Structure ) \
63
- and isinstance (v .type (), (PointFieldType , MultiPolygonFieldType )):
69
+ try :
70
+ converted = convert_mongoengine_field (getattr (self .model , k ), self .registry )
71
+ except MongoEngineConversionError :
72
+ return False
73
+ if isinstance (converted , (ConnectionField , Dynamic , List )):
74
+ return False
75
+ if callable (getattr (converted , 'type' , None )) and isinstance (converted .type (),
76
+ (PointFieldType , MultiPolygonFieldType )):
64
77
return False
65
78
return True
66
79
@@ -69,7 +82,7 @@ def get_type(v):
69
82
return v .type .of_type ()
70
83
return v .type ()
71
84
72
- return {k : get_type (v ) for k , v in items if is_filterable (v )}
85
+ return {k : get_type (v ) for k , v in items if is_filterable (k )}
73
86
74
87
@property
75
88
def field_args (self ):
@@ -78,19 +91,26 @@ def field_args(self):
78
91
@property
79
92
def reference_args (self ):
80
93
def get_reference_field (r , kv ):
81
- if callable (getattr (kv [1 ], 'get_type' , None )):
82
- node = kv [1 ].get_type ()._type ._meta
83
- if not issubclass (node .model , mongoengine .EmbeddedDocument ):
84
- r .update ({kv [0 ]: node .fields ['id' ]._type .of_type ()})
94
+ field = kv [1 ]
95
+ mongo_field = getattr (self .model , kv [0 ], None )
96
+ if isinstance (mongo_field , (mongoengine .LazyReferenceField , mongoengine .ReferenceField )):
97
+ field = convert_mongoengine_field (mongo_field , self .registry )
98
+ if callable (getattr (field , 'get_type' , None )):
99
+ _type = field .get_type ()
100
+ if _type :
101
+ node = _type ._type ._meta
102
+ if 'id' in node .fields and not issubclass (node .model , mongoengine .EmbeddedDocument ):
103
+ r .update ({kv [0 ]: node .fields ['id' ]._type .of_type ()})
85
104
return r
105
+
86
106
return reduce (get_reference_field , self .fields .items (), {})
87
107
88
108
@property
89
109
def fields (self ):
90
110
return self ._type ._meta .fields
91
111
92
112
@classmethod
93
- def get_query (cls , model , info , ** args ):
113
+ def get_query (cls , model , connection , info , ** args ):
94
114
95
115
if not callable (getattr (model , 'objects' , None )):
96
116
return [], 0
@@ -102,20 +122,20 @@ def get_query(cls, model, info, **args):
102
122
for arg_name , arg in args .copy ().items ():
103
123
if arg_name in reference_fields :
104
124
reference_model = model ._fields [arg_name ]
105
- pk = from_global_id ( args .pop (arg_name ))[- 1 ]
125
+ pk = node_from_global_id ( connection , args .pop (arg_name ))[- 1 ]
106
126
reference_obj = reference_model .document_type_obj .objects (pk = pk ).get ()
107
127
reference_args [arg_name ] = reference_obj
108
128
109
129
args .update (reference_args )
110
130
first = args .pop ('first' , None )
111
131
last = args .pop ('last' , None )
112
- id = args .pop ('id' , None )
132
+ _id = args .pop ('id' , None )
113
133
before = args .pop ('before' , None )
114
134
after = args .pop ('after' , None )
115
135
116
- if id is not None :
136
+ if _id is not None :
117
137
# https://github.com/graphql-python/graphene/issues/124
118
- args ['pk' ] = from_global_id ( id )[- 1 ]
138
+ args ['pk' ] = node_from_global_id ( connection , _id )[- 1 ]
119
139
120
140
objs = objs .filter (** args )
121
141
@@ -152,14 +172,14 @@ def merge_querysets(cls, default_queryset, queryset):
152
172
def connection_resolver (cls , resolver , connection , model , root , info , ** args ):
153
173
iterable = resolver (root , info , ** args )
154
174
155
- if not iterable :
156
- iterable , _len = cls .get_query (model , info , ** args )
175
+ if iterable or iterable == []:
176
+ _len = len (iterable )
177
+ else :
178
+ iterable , _len = cls .get_query (model , connection , info , ** args )
157
179
158
180
if root :
159
181
# If we have a root, we must be at least 1 layer in, right?
160
182
_len = 0
161
- else :
162
- _len = len (iterable )
163
183
164
184
connection = connection_from_list_slice (
165
185
iterable ,
0 commit comments