Skip to content

Commit fc58f3f

Browse files
decazsloria
authored andcommitted
Use callable as schema within @use_kwargs (#79)
* Change `resolve_instance` function (also rename to `resolve_resource`) * Fix schema handling within the `Converter` class * Fix resolving schema withing the `Wrapper` class * Add .idea directory to the .gitignore * Add tests * Cosmetic fix
1 parent 17333a8 commit fc58f3f

File tree

6 files changed

+66
-17
lines changed

6 files changed

+66
-17
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,3 +47,6 @@ docs/_build
4747
README.html
4848

4949
_sandbox
50+
51+
# JetBrains
52+
.idea/

flask_apispec/apidoc.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from marshmallow.utils import is_instance_or_subclass
1414

1515
from flask_apispec.paths import rule_to_path, rule_to_params
16-
from flask_apispec.utils import resolve_instance, resolve_annotations, merge_recursive
16+
from flask_apispec.utils import resolve_resource, resolve_annotations, merge_recursive
1717

1818
class Converter(object):
1919

@@ -59,11 +59,17 @@ def get_parent(self, view):
5959
def get_parameters(self, rule, view, docs, parent=None):
6060
annotation = resolve_annotations(view, 'args', parent)
6161
args = merge_recursive(annotation.options)
62-
converter = (
63-
swagger.schema2parameters
64-
if is_instance_or_subclass(args.get('args', {}), Schema)
65-
else swagger.fields2parameters
66-
)
62+
schema = args.get('args', {})
63+
if is_instance_or_subclass(schema, Schema):
64+
converter = swagger.schema2parameters
65+
elif callable(schema):
66+
schema = schema(request=None)
67+
if is_instance_or_subclass(schema, Schema):
68+
converter = swagger.schema2parameters
69+
else:
70+
converter = swagger.fields2parameters
71+
else:
72+
converter = swagger.fields2parameters
6773
options = copy.copy(args.get('kwargs', {}))
6874
locations = options.pop('locations', None)
6975
if locations:
@@ -72,10 +78,7 @@ def get_parameters(self, rule, view, docs, parent=None):
7278
options['dump'] = False
7379

7480
rule_params = rule_to_params(rule, docs.get('params')) or []
75-
extra_params = converter(
76-
args.get('args', {}),
77-
**options
78-
) if args else []
81+
extra_params = converter(schema, **options) if args else []
7982

8083
return extra_params + rule_params
8184

@@ -98,4 +101,4 @@ def get_operations(self, rule, resource):
98101
}
99102

100103
def get_parent(self, resource, **kwargs):
101-
return resolve_instance(resource, **kwargs)
104+
return resolve_resource(resource, **kwargs)

flask_apispec/utils.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,20 @@
33
import functools
44

55
import six
6+
import marshmallow as ma
67

7-
def resolve_instance(schema, **kwargs):
8-
kwargs = kwargs or {}
8+
def resolve_resource(resource, **kwargs):
99
resource_class_args = kwargs.get('resource_class_args') or ()
1010
resource_class_kwargs = kwargs.get('resource_class_kwargs') or {}
11-
if isinstance(schema, type):
12-
return schema(*resource_class_args, **resource_class_kwargs)
11+
if isinstance(resource, type):
12+
return resource(*resource_class_args, **resource_class_kwargs)
13+
return resource
14+
15+
def resolve_schema(schema, request=None):
16+
if isinstance(schema, type) and issubclass(schema, ma.Schema):
17+
schema = schema()
18+
elif callable(schema):
19+
schema = schema(request)
1320
return schema
1421

1522
class Ref(object):

flask_apispec/wrapper.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def call_view(self, *args, **kwargs):
3333
annotation = utils.resolve_annotations(self.func, 'args', self.instance)
3434
if annotation.apply is not False:
3535
for option in annotation.options:
36-
schema = utils.resolve_instance(option['args'])
36+
schema = utils.resolve_schema(option['args'], request=flask.request)
3737
parsed = parser.parse(schema, locations=option['kwargs']['locations'])
3838
if getattr(schema, 'many', False):
3939
args += tuple(parsed)
@@ -48,7 +48,7 @@ def marshal_result(self, unpacked, status_code):
4848
schemas = utils.merge_recursive(annotation.options)
4949
schema = schemas.get(status_code, schemas.get('default'))
5050
if schema and annotation.apply is not False:
51-
schema = utils.resolve_instance(schema['schema'])
51+
schema = utils.resolve_schema(schema['schema'], request=flask.request)
5252
output = schema.dump(unpacked[0]).data
5353
else:
5454
output = unpacked[0]

tests/test_swagger.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,22 @@ def test_params(self, app, path):
9191
)
9292
assert params == expected
9393

94+
class TestCallableAsArgSchema(TestArgSchema):
95+
96+
@pytest.fixture
97+
def function_view(self, app, models, schemas):
98+
def schema_factory(request):
99+
class ArgSchema(Schema):
100+
name = fields.Str()
101+
102+
return ArgSchema
103+
104+
@app.route('/bands/<int:band_id>/')
105+
@use_kwargs(schema_factory, locations=('query', ))
106+
def get_band(**kwargs):
107+
return kwargs
108+
return get_band
109+
94110
class TestDeleteView:
95111

96112
@pytest.fixture

tests/test_views.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,26 @@ def view(**kwargs):
5454
res = client.get('/', {'name': 'freddie', 'instrument': 'vocals'})
5555
assert res.json == {'name': 'freddie', 'instrument': 'vocals'}
5656

57+
def test_use_kwargs_callable_as_schema(self, app, client):
58+
def schema_factory(request):
59+
assert request.method == 'GET'
60+
assert request.path == '/'
61+
62+
class ArgSchema(Schema):
63+
name = fields.Str()
64+
65+
class Meta:
66+
strict = True
67+
68+
return ArgSchema
69+
70+
@app.route('/')
71+
@use_kwargs(schema_factory)
72+
def view(**kwargs):
73+
return kwargs
74+
res = client.get('/', {'name': 'freddie'})
75+
assert res.json == {'name': 'freddie'}
76+
5777
def test_marshal_with_default(self, app, client, models, schemas):
5878
@app.route('/')
5979
@marshal_with(schemas.BandSchema)

0 commit comments

Comments
 (0)