|
1 | 1 | from functools import partial
|
2 | 2 |
|
3 | 3 | from django.db.models.query import QuerySet
|
| 4 | +from graphene import NonNull |
4 | 5 |
|
5 | 6 | from promise import Promise
|
6 | 7 |
|
@@ -45,17 +46,31 @@ def type(self):
|
45 | 46 | from .types import DjangoObjectType
|
46 | 47 |
|
47 | 48 | _type = super(ConnectionField, self).type
|
| 49 | + non_null = False |
| 50 | + if isinstance(_type, NonNull): |
| 51 | + _type = _type.of_type |
| 52 | + non_null = True |
48 | 53 | assert issubclass(
|
49 | 54 | _type, DjangoObjectType
|
50 | 55 | ), "DjangoConnectionField only accepts DjangoObjectType types"
|
51 | 56 | assert _type._meta.connection, "The type {} doesn't have a connection".format(
|
52 | 57 | _type.__name__
|
53 | 58 | )
|
54 |
| - return _type._meta.connection |
| 59 | + connection_type = _type._meta.connection |
| 60 | + if non_null: |
| 61 | + return NonNull(connection_type) |
| 62 | + return connection_type |
| 63 | + |
| 64 | + @property |
| 65 | + def connection_type(self): |
| 66 | + type = self.type |
| 67 | + if isinstance(type, NonNull): |
| 68 | + return type.of_type |
| 69 | + return type |
55 | 70 |
|
56 | 71 | @property
|
57 | 72 | def node_type(self):
|
58 |
| - return self.type._meta.node |
| 73 | + return self.connection_type._meta.node |
59 | 74 |
|
60 | 75 | @property
|
61 | 76 | def model(self):
|
@@ -103,15 +118,15 @@ def resolve_connection(cls, connection, default_manager, args, iterable):
|
103 | 118 |
|
104 | 119 | @classmethod
|
105 | 120 | def connection_resolver(
|
106 |
| - cls, |
107 |
| - resolver, |
108 |
| - connection, |
109 |
| - default_manager, |
110 |
| - max_limit, |
111 |
| - enforce_first_or_last, |
112 |
| - root, |
113 |
| - info, |
114 |
| - **args |
| 121 | + cls, |
| 122 | + resolver, |
| 123 | + connection, |
| 124 | + default_manager, |
| 125 | + max_limit, |
| 126 | + enforce_first_or_last, |
| 127 | + root, |
| 128 | + info, |
| 129 | + **args |
115 | 130 | ):
|
116 | 131 | first = args.get("first")
|
117 | 132 | last = args.get("last")
|
@@ -146,7 +161,7 @@ def get_resolver(self, parent_resolver):
|
146 | 161 | return partial(
|
147 | 162 | self.connection_resolver,
|
148 | 163 | parent_resolver,
|
149 |
| - self.type, |
| 164 | + self.connection_type, |
150 | 165 | self.get_manager(),
|
151 | 166 | self.max_limit,
|
152 | 167 | self.enforce_first_or_last,
|
|
0 commit comments