Skip to content

Commit 98825fa

Browse files
committed
Added optional default_resolver to ObjectType.
1 parent 62e58bd commit 98825fa

File tree

4 files changed

+73
-6
lines changed

4 files changed

+73
-6
lines changed

graphene/types/objecttype.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,13 @@ def __new__(cls, name, bases, attrs):
1919
if not is_base_type(bases, ObjectTypeMeta):
2020
return type.__new__(cls, name, bases, attrs)
2121

22-
_meta = attrs.pop('_meta', None)
23-
options = _meta or Options(
22+
attrs.pop('_meta', None)
23+
options = Options(
2424
attrs.pop('Meta', None),
2525
name=name,
2626
description=trim_docstring(attrs.get('__doc__')),
2727
interfaces=(),
28+
default_resolver=None,
2829
local_fields=OrderedDict(),
2930
)
3031
options.base_fields = get_base_fields(bases, _as=Field)

graphene/types/resolver.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
def attr_resolver(attname, default_value, root, args, context, info):
2+
return getattr(root, attname, default_value)
3+
4+
5+
def dict_resolver(attname, default_value, root, args, context, info):
6+
return root.get(attname, default_value)
7+
8+
9+
default_resolver = attr_resolver
10+
11+
12+
def set_default_resolver(resolver):
13+
global default_resolver
14+
assert callable(resolver), 'Received non-callable resolver.'
15+
default_resolver = resolver
16+
17+
18+
def get_default_resolver():
19+
return default_resolver

graphene/types/tests/test_resolver.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import pytest
2+
3+
from ..resolver import attr_resolver, dict_resolver, get_default_resolver, set_default_resolver
4+
5+
args = {}
6+
context = None
7+
info = None
8+
9+
demo_dict = {
10+
'attr': 'value'
11+
}
12+
13+
14+
class demo_obj(object):
15+
attr = 'value'
16+
17+
18+
def test_attr_resolver():
19+
resolved = attr_resolver('attr', None, demo_obj, args, context, info)
20+
assert resolved == 'value'
21+
22+
23+
def test_attr_resolver_default_value():
24+
resolved = attr_resolver('attr2', 'default', demo_obj, args, context, info)
25+
assert resolved == 'default'
26+
27+
28+
def test_dict_resolver():
29+
resolved = dict_resolver('attr', None, demo_dict, args, context, info)
30+
assert resolved == 'value'
31+
32+
33+
def test_dict_resolver_default_value():
34+
resolved = dict_resolver('attr2', 'default', demo_dict, args, context, info)
35+
assert resolved == 'default'
36+
37+
38+
def test_get_default_resolver_is_attr_resolver():
39+
assert get_default_resolver() == attr_resolver
40+
41+
42+
def test_set_default_resolver_workd():
43+
default_resolver = get_default_resolver()
44+
45+
set_default_resolver(dict_resolver)
46+
assert get_default_resolver() == dict_resolver
47+
48+
set_default_resolver(default_resolver)

graphene/types/typemap.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from .inputobjecttype import InputObjectType
2222
from .interface import Interface
2323
from .objecttype import ObjectType
24+
from .resolver import get_default_resolver
2425
from .scalars import ID, Boolean, Float, Int, Scalar, String
2526
from .structures import List, NonNull
2627
from .union import Union
@@ -205,9 +206,6 @@ def get_name(self, name):
205206
return to_camel_case(name)
206207
return name
207208

208-
def default_resolver(self, attname, default_value, root, *_):
209-
return getattr(root, attname, default_value)
210-
211209
def construct_fields_for_type(self, map, type, is_input_type=False):
212210
fields = OrderedDict()
213211
for name, field in type._meta.fields.items():
@@ -267,7 +265,8 @@ def get_resolver_for_type(self, type, name, default_value):
267265
if resolver:
268266
return get_unbound_function(resolver)
269267

270-
return partial(self.default_resolver, name, default_value)
268+
default_resolver = type._meta.default_resolver or get_default_resolver()
269+
return partial(default_resolver, name, default_value)
271270

272271
def get_field_type(self, map, type):
273272
if isinstance(type, List):

0 commit comments

Comments
 (0)