Skip to content
Merged
3 changes: 3 additions & 0 deletions .github/workflows/pullrequests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ jobs:
- marshmallow==3.17.*
- marshmallow==3.18.*
- marshmallow>3.18.0
- marshmallow==4.0.0
- marshmallow==4.0.1
- marshmallow>4.0.1
flask:
- flask=='2.2.*' werkzeug=='2.2.*'
- flask=='2.3.*' werkzeug=='2.3.*'
Expand Down
27 changes: 27 additions & 0 deletions flask_rebar/compat.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
from typing import Any
from typing import Dict

try:
from importlib.metadata import version
except ImportError:
from importlib_metadata import version # type: ignore

import marshmallow
from marshmallow.fields import Field
from marshmallow.schema import Schema
Expand Down Expand Up @@ -64,3 +69,25 @@ def dump(schema: Schema, data: Dict[str, Any]) -> Dict[str, Any]:
def exclude_unknown_fields(schema: Schema) -> Schema:
schema.unknown = marshmallow.EXCLUDE
return schema


# Marshmallow version detection for backward compatibility
MARSHMALLOW_VERSION_MAJOR = int(version('marshmallow').split('.')[0])

def is_schema_ordered(schema: Schema) -> bool:
"""
Check if a schema should maintain field order.

In Marshmallow 3.x, this is controlled by the 'ordered' attribute.
In Marshmallow 4.x+, field order is always preserved (insertion order from dict).

:param Schema schema: The schema to check
:return: True if fields should maintain their order, False if they should be sorted
:rtype: bool
"""
if MARSHMALLOW_VERSION_MAJOR >= 4:
# In Marshmallow 4+, fields are always ordered (insertion order)
return True
else:
# In Marshmallow 3, check the 'ordered' attribute
return getattr(schema, 'ordered', False)
17 changes: 13 additions & 4 deletions flask_rebar/swagger_generation/marshmallow_to_swagger.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,13 @@ def get_schema_fields(schema: Schema) -> List[Tuple[str, m.fields.Field]]:
for name, field in schema.fields.items():
prop = compat.get_data_key(field)
fields.append((prop, field))
return sorted(fields)

# In Marshmallow 3.x, respect the 'ordered' Meta option for field ordering.
# When ordered=False (default), fields should be sorted alphabetically.
# In Marshmallow 4.0+, field order is always preserved (insertion order).
if not compat.is_schema_ordered(schema):
fields.sort()
return fields


class MarshmallowConverter(Generic[T]):
Expand Down Expand Up @@ -274,9 +280,12 @@ def get_required(
continue
required.append(prop)

if required and not obj.ordered:
required = sorted(required)
return required if required else UNSET
if not required:
return UNSET

if not compat.is_schema_ordered(obj):
required.sort()
return required

@sets_swagger_attr(sw.description)
def get_description(
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

install_requires = [
"Flask>=1.0,<4",
"marshmallow>=3.0,<4",
"marshmallow>=3.0,<5",
"typing-extensions>=4.8,<5;python_version<'3.10'",
"Werkzeug>=2.2,<4",
]
Expand Down
9 changes: 8 additions & 1 deletion tests/swagger_generation/test_marshmallow_to_swagger.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

from flask_rebar.validation import CommaSeparatedList
from flask_rebar.validation import QueryParamList
from flask_rebar.compat import MARSHMALLOW_VERSION_MAJOR


class StopLight(enum.Enum):
Expand Down Expand Up @@ -312,6 +313,9 @@ class Foo(m.Schema):
schema = Foo()
json_schema = self.registry.convert(schema)

# Compare required as a set since order may vary between marshmallow versions
required = json_schema.pop("required")
self.assertEqual(set(required), {"a", "b"})
self.assertEqual(
json_schema,
{
Expand All @@ -323,7 +327,6 @@ class Foo(m.Schema):
"a": {"type": "integer"},
"c": {"type": "integer"},
},
"required": ["a", "b"],
},
)

Expand Down Expand Up @@ -421,6 +424,10 @@ class Foo(m.Schema):
},
)

@pytest.mark.skipif(
MARSHMALLOW_VERSION_MAJOR >= 4,
reason="'self' nested reference removed in marshmallow 4.x"
)
def test_self_referential_nested_pre_3_3(self):
# Issue 90
# note for Marshmallow >= 3.3, preferred format is e.g.,:
Expand Down
103 changes: 70 additions & 33 deletions tests/test_rebar.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,25 @@
import unittest

import marshmallow as m
import marshmallow_objects as mo
from flask import Flask, make_response

from parametrize import parametrize

# marshmallow-objects is archived and incompatible with marshmallow 4.x
# Make it optional for tests
try:
import marshmallow_objects as mo
# Test that marshmallow-objects actually works with current marshmallow version
# by trying to create a simple model - it fails with marshmallow 4.x
class _TestModel(mo.Model):
test = mo.fields.String()
_TestModel() # This will fail if incompatible
MARSHMALLOW_OBJECTS_AVAILABLE = True
del _TestModel
except (ImportError, TypeError):
mo = None # type: ignore
MARSHMALLOW_OBJECTS_AVAILABLE = False

from flask_rebar import messages
from flask_rebar import HeaderApiKeyAuthenticator, SwaggerV3Generator
from flask_rebar.compat import set_data_key
Expand All @@ -41,41 +55,67 @@ class FooSchema(m.Schema):
name = m.fields.String()


class FooModel(mo.Model):
uid = mo.fields.String()
name = mo.fields.String()


class ListOfFooSchema(m.Schema):
data = m.fields.Nested(FooSchema, many=True)


class ListOfFooModel(mo.Model):
data = mo.NestedModel(FooModel, many=True)


class FooUpdateSchema(m.Schema):
name = m.fields.String()


class FooUpdateModel(mo.Model):
name = mo.fields.String()


class FooListSchema(m.Schema):
name = m.fields.String(required=True)


class FooListModel(mo.Model):
name = mo.fields.String(required=True)


class HeadersSchema(m.Schema):
name = set_data_key(field=m.fields.String(required=True), key="x-name")


class HeadersModel(mo.Model):
name = set_data_key(field=mo.fields.String(required=True), key="x-name")
# marshmallow-objects Model classes (only defined if library is available)
if MARSHMALLOW_OBJECTS_AVAILABLE:
class FooModel(mo.Model):
uid = mo.fields.String()
name = mo.fields.String()

class ListOfFooModel(mo.Model):
data = mo.NestedModel(FooModel, many=True)

class FooUpdateModel(mo.Model):
name = mo.fields.String()

class FooListModel(mo.Model):
name = mo.fields.String(required=True)

class HeadersModel(mo.Model):
name = set_data_key(field=mo.fields.String(required=True), key="x-name")
else:
# Placeholders when marshmallow-objects is not available
FooModel = None # type: ignore
ListOfFooModel = None # type: ignore
FooUpdateModel = None # type: ignore
FooListModel = None # type: ignore
HeadersModel = None # type: ignore


# Parametrize test data - include model tests only when marshmallow-objects is available
_body_params_test_cases = [(FooSchema, FooUpdateSchema, False)]
_foo_update_cls_test_cases = [(FooUpdateSchema,)]
_list_of_foo_cls_test_cases = [(ListOfFooSchema,)]
_headers_cls_test_cases = [(HeadersSchema, False)]
_foo_definition_cls_test_cases = [(FooSchema,)]
_foo_definition_instance_test_cases = [(FooSchema(),)]
_headers_def_test_cases = [(HeadersSchema(), False)]
_schema_cls_test_cases = [(FooSchema, FooListSchema, HeadersSchema)]

if MARSHMALLOW_OBJECTS_AVAILABLE:
_body_params_test_cases.append((FooModel, FooUpdateModel, True))
_foo_update_cls_test_cases.append((FooUpdateModel,))
_list_of_foo_cls_test_cases.append((ListOfFooModel,))
_headers_cls_test_cases.append((HeadersModel, True))
_foo_definition_cls_test_cases.append((FooModel,))
_foo_definition_instance_test_cases.append((FooModel(),))
_headers_def_test_cases.append((HeadersModel, True))
_schema_cls_test_cases.append((FooModel, FooListModel, HeadersModel))


class MeSchema(m.Schema):
Expand Down Expand Up @@ -241,7 +281,7 @@ def test_override_with_no_authenticator(self):

@parametrize(
"foo_cls,foo_update_cls,use_model",
[(FooSchema, FooUpdateSchema, False), (FooModel, FooUpdateModel, True)],
_body_params_test_cases,
)
def test_validate_body_parameters(self, foo_cls, foo_update_cls, use_model):
rebar = Rebar()
Expand Down Expand Up @@ -278,7 +318,7 @@ def update_foo(foo_uid):
)
self.assertEqual(resp.status_code, 400)

@parametrize("foo_update_cls", [(FooUpdateSchema,), (FooUpdateModel,)])
@parametrize("foo_update_cls", _foo_update_cls_test_cases)
def test_flask_response_instance_interop_body_matches_schema(self, foo_update_cls):
rebar = Rebar()
registry = rebar.create_handler_registry()
Expand All @@ -293,7 +333,7 @@ def foo():
self.assertEqual(resp.status_code, 200)
self.assertEqual(resp.headers["foo"], "bar")

@parametrize("foo_update_cls", [(FooUpdateSchema,), (FooUpdateModel,)])
@parametrize("foo_update_cls", _foo_update_cls_test_cases)
def test_flask_response_instance_interop_body_does_not_match_schema(
self, foo_update_cls
):
Expand Down Expand Up @@ -322,7 +362,7 @@ def foo():
self.assertEqual(resp.status_code, 302)
self.assertEqual(resp.headers["Location"], "http://foo.com")

@parametrize("list_of_foo_cls", [(ListOfFooSchema,), (ListOfFooModel,)])
@parametrize("list_of_foo_cls", _list_of_foo_cls_test_cases)
def test_validate_query_parameters(self, list_of_foo_cls):
rebar = Rebar()
registry = rebar.create_handler_registry()
Expand All @@ -346,7 +386,7 @@ def list_foos():
self.assertEqual(resp.status_code, 400)

@parametrize(
"headers_cls, use_model", [(HeadersSchema, False), (HeadersModel, True)]
"headers_cls, use_model", _headers_cls_test_cases
)
def test_validate_headers(self, headers_cls, use_model):
rebar = Rebar()
Expand Down Expand Up @@ -473,7 +513,7 @@ def delete_me():
self.assertEqual(resp.data.decode("utf-8"), "")
self.assertEqual(resp.headers["Content-Type"], "handler/type")

@parametrize("foo_definition", [(FooSchema,), (FooModel,)])
@parametrize("foo_definition", _foo_definition_cls_test_cases)
def test_view_function_tuple_response(self, foo_definition):
header_key = "X-Foo"
header_value = "bar"
Expand Down Expand Up @@ -570,7 +610,7 @@ def test_swagger_can_be_set_to_v3(self):
resp = app.test_client().get("/swagger/ui/")
self.assertEqual(resp.status_code, 200)

@parametrize("foo_definition", [(FooSchema(),), (FooModel(),)])
@parametrize("foo_definition", _foo_definition_instance_test_cases)
def test_register_multiple_paths(self, foo_definition):
rebar = Rebar()
registry = rebar.create_handler_registry()
Expand All @@ -596,7 +636,7 @@ def handler_func(foo_uid):
self.assertIn("/bars/{foo_uid}", swagger["paths"])
self.assertIn("/foos/{foo_uid}", swagger["paths"])

@parametrize("foo_definition", [(FooSchema(),), (FooModel(),)])
@parametrize("foo_definition", _foo_definition_instance_test_cases)
def test_register_multiple_methods(self, foo_definition):
rebar = Rebar()
registry = rebar.create_handler_registry()
Expand Down Expand Up @@ -629,7 +669,7 @@ def handler_func(foo_uid):
self.assertIn("patch", swagger["paths"]["/foos/{foo_uid}"])

@parametrize(
"headers_def, use_model", [(HeadersSchema(), False), (HeadersModel, True)]
"headers_def, use_model", _headers_def_test_cases
)
def test_default_headers(self, headers_def, use_model):
rebar = Rebar()
Expand Down Expand Up @@ -784,10 +824,7 @@ def test_redirects_for_missing_trailing_slash(self):

@parametrize(
"foo_cls, foo_list_cls, headers_cls",
[
(FooSchema, FooListSchema, HeadersSchema),
(FooModel, FooListModel, HeadersModel),
],
_schema_cls_test_cases,
)
def test_bare_class_schemas_handled(self, foo_cls, foo_list_cls, headers_cls):
rebar = Rebar()
Expand Down
8 changes: 5 additions & 3 deletions tests/test_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,11 @@ def test_validation_works(self):
self.data["validation_required"] = "123"
with self.assertRaises(ValidationError) as ctx:
compat.dump(self.validated_schema, self.data)
# it's some sort of date error
self.assertIn(
"'str' object has no attribute 'isoformat'", ctx.exception.messages[0]
# it's some sort of date error - message format varies by marshmallow version
error_msg = str(ctx.exception.messages)
self.assertTrue(
"isoformat" in error_msg,
f"Expected 'isoformat' error message, got: {error_msg}"
)

def test_required_failed_validate(self):
Expand Down
Loading