1
1
from functools import partial
2
-
2
+ from promise import is_thenable , Promise
3
3
from sqlalchemy .orm .query import Query
4
4
5
5
from graphene .relay import ConnectionField
@@ -25,39 +25,38 @@ def get_query(cls, model, info, sort=None, **args):
25
25
query = query .order_by (* (col .value for col in sort ))
26
26
return query
27
27
28
- @property
29
- def type (self ):
30
- from .types import SQLAlchemyObjectType
31
- _type = super (ConnectionField , self ).type
32
- assert issubclass (_type , SQLAlchemyObjectType ), (
33
- "SQLAlchemyConnectionField only accepts SQLAlchemyObjectType types"
34
- )
35
- assert _type ._meta .connection , "The type {} doesn't have a connection" .format (_type .__name__ )
36
- return _type ._meta .connection
37
-
38
28
@classmethod
39
- def connection_resolver (cls , resolver , connection , model , root , info , ** args ):
40
- iterable = resolver (root , info , ** args )
41
- if iterable is None :
42
- iterable = cls .get_query (model , info , ** args )
43
- if isinstance (iterable , Query ):
44
- _len = iterable .count ()
29
+ def resolve_connection (cls , connection_type , model , info , args , resolved ):
30
+ if resolved is None :
31
+ resolved = cls .get_query (model , info , ** args )
32
+ if isinstance (resolved , Query ):
33
+ _len = resolved .count ()
45
34
else :
46
- _len = len (iterable )
35
+ _len = len (resolved )
47
36
connection = connection_from_list_slice (
48
- iterable ,
37
+ resolved ,
49
38
args ,
50
39
slice_start = 0 ,
51
40
list_length = _len ,
52
41
list_slice_length = _len ,
53
- connection_type = connection ,
42
+ connection_type = connection_type ,
54
43
pageinfo_type = PageInfo ,
55
- edge_type = connection .Edge ,
44
+ edge_type = connection_type .Edge ,
56
45
)
57
- connection .iterable = iterable
46
+ connection .iterable = resolved
58
47
connection .length = _len
59
48
return connection
60
49
50
+ @classmethod
51
+ def connection_resolver (cls , resolver , connection_type , model , root , info , ** args ):
52
+ resolved = resolver (root , info , ** args )
53
+
54
+ on_resolve = partial (cls .resolve_connection , connection_type , model , info , args )
55
+ if is_thenable (resolved ):
56
+ return Promise .resolve (resolved ).then (on_resolve )
57
+
58
+ return on_resolve (resolved )
59
+
61
60
def get_resolver (self , parent_resolver ):
62
61
return partial (self .connection_resolver , parent_resolver , self .type , self .model )
63
62
0 commit comments