1
1
from functools import partial
2
-
2
+ from promise import is_thenable , Promise
3
3
from sqlalchemy .orm .query import Query
4
4
5
5
from graphene .relay import ConnectionField
@@ -19,39 +19,38 @@ def model(self):
19
19
def get_query (cls , model , info , ** args ):
20
20
return get_query (model , info .context )
21
21
22
- @property
23
- def type (self ):
24
- from .types import SQLAlchemyObjectType
25
- _type = super (ConnectionField , self ).type
26
- assert issubclass (_type , SQLAlchemyObjectType ), (
27
- "SQLAlchemyConnectionField only accepts SQLAlchemyObjectType types"
28
- )
29
- assert _type ._meta .connection , "The type {} doesn't have a connection" .format (_type .__name__ )
30
- return _type ._meta .connection
31
-
32
22
@classmethod
33
- def connection_resolver (cls , resolver , connection , model , root , info , ** args ):
34
- iterable = resolver (root , info , ** args )
35
- if iterable is None :
36
- iterable = cls .get_query (model , info , ** args )
37
- if isinstance (iterable , Query ):
38
- _len = iterable .count ()
23
+ def resolve_connection (cls , connection_type , model , info , args , resolved ):
24
+ if resolved is None :
25
+ resolved = cls .get_query (model , info , ** args )
26
+ if isinstance (resolved , Query ):
27
+ _len = resolved .count ()
39
28
else :
40
- _len = len (iterable )
29
+ _len = len (resolved )
41
30
connection = connection_from_list_slice (
42
- iterable ,
31
+ resolved ,
43
32
args ,
44
33
slice_start = 0 ,
45
34
list_length = _len ,
46
35
list_slice_length = _len ,
47
- connection_type = connection ,
36
+ connection_type = connection_type ,
48
37
pageinfo_type = PageInfo ,
49
- edge_type = connection .Edge ,
38
+ edge_type = connection_type .Edge ,
50
39
)
51
- connection .iterable = iterable
40
+ connection .iterable = resolved
52
41
connection .length = _len
53
42
return connection
54
43
44
+ @classmethod
45
+ def connection_resolver (cls , resolver , connection_type , model , root , info , ** args ):
46
+ resolved = resolver (root , info , ** args )
47
+
48
+ on_resolve = partial (cls .resolve_connection , connection_type , model , info , args )
49
+ if is_thenable (resolved ):
50
+ return Promise .resolve (resolved ).then (on_resolve )
51
+
52
+ return on_resolve (resolved )
53
+
55
54
def get_resolver (self , parent_resolver ):
56
55
return partial (self .connection_resolver , parent_resolver , self .type , self .model )
57
56
0 commit comments