Skip to content

Commit 603aac7

Browse files
authored
Corrected OpenAPI schema type for DecimalField (#7254)
1 parent 41f27c3 commit 603aac7

File tree

3 files changed

+30
-8
lines changed

3 files changed

+30
-8
lines changed

rest_framework/schemas/openapi.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from rest_framework import exceptions, renderers, serializers
1616
from rest_framework.compat import uritemplate
1717
from rest_framework.fields import _UnvalidatedField, empty
18+
from rest_framework.settings import api_settings
1819

1920
from .generators import BaseSchemaGenerator
2021
from .inspectors import ViewInspector
@@ -446,11 +447,17 @@ def _map_field(self, field):
446447
content['format'] = field.protocol
447448
return content
448449

449-
# DecimalField has multipleOf based on decimal_places
450450
if isinstance(field, serializers.DecimalField):
451-
content = {
452-
'type': 'number'
453-
}
451+
if getattr(field, 'coerce_to_string', api_settings.COERCE_DECIMAL_TO_STRING):
452+
content = {
453+
'type': 'string',
454+
'format': 'decimal',
455+
}
456+
else:
457+
content = {
458+
'type': 'number'
459+
}
460+
454461
if field.decimal_places:
455462
content['multipleOf'] = float('.' + (field.decimal_places - 1) * '0' + '1')
456463
if field.max_whole_digits:
@@ -461,7 +468,7 @@ def _map_field(self, field):
461468

462469
if isinstance(field, serializers.FloatField):
463470
content = {
464-
'type': 'number'
471+
'type': 'number',
465472
}
466473
self._map_min_max(field, content)
467474
return content
@@ -560,7 +567,8 @@ def _map_field_validators(self, field, schema):
560567
schema['maximum'] = v.limit_value
561568
elif isinstance(v, MinValueValidator):
562569
schema['minimum'] = v.limit_value
563-
elif isinstance(v, DecimalValidator):
570+
elif isinstance(v, DecimalValidator) and \
571+
not getattr(field, 'coerce_to_string', api_settings.COERCE_DECIMAL_TO_STRING):
564572
if v.decimal_places:
565573
schema['multipleOf'] = float('.' + (v.decimal_places - 1) * '0' + '1')
566574
if v.max_digits:

tests/schemas/test_openapi.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -838,6 +838,16 @@ def test_serializer_validators(self):
838838
assert properties['decimal2']['type'] == 'number'
839839
assert properties['decimal2']['multipleOf'] == .0001
840840

841+
assert properties['decimal3'] == {
842+
'type': 'string', 'format': 'decimal', 'maximum': 1000000, 'minimum': -1000000, 'multipleOf': 0.01
843+
}
844+
assert properties['decimal4'] == {
845+
'type': 'string', 'format': 'decimal', 'maximum': 1000000, 'minimum': -1000000, 'multipleOf': 0.01
846+
}
847+
assert properties['decimal5'] == {
848+
'type': 'string', 'format': 'decimal', 'maximum': 10000, 'minimum': -10000, 'multipleOf': 0.01
849+
}
850+
841851
assert properties['email']['type'] == 'string'
842852
assert properties['email']['format'] == 'email'
843853
assert properties['email']['default'] == '[email protected]'

tests/schemas/views.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,9 +119,13 @@ class ExampleValidatedSerializer(serializers.Serializer):
119119
MinLengthValidator(limit_value=2),
120120
)
121121
)
122-
decimal1 = serializers.DecimalField(max_digits=6, decimal_places=2)
123-
decimal2 = serializers.DecimalField(max_digits=5, decimal_places=0,
122+
decimal1 = serializers.DecimalField(max_digits=6, decimal_places=2, coerce_to_string=False)
123+
decimal2 = serializers.DecimalField(max_digits=5, decimal_places=0, coerce_to_string=False,
124124
validators=(DecimalValidator(max_digits=17, decimal_places=4),))
125+
decimal3 = serializers.DecimalField(max_digits=8, decimal_places=2, coerce_to_string=True)
126+
decimal4 = serializers.DecimalField(max_digits=8, decimal_places=2, coerce_to_string=True,
127+
validators=(DecimalValidator(max_digits=17, decimal_places=4),))
128+
decimal5 = serializers.DecimalField(max_digits=6, decimal_places=2)
125129
email = serializers.EmailField(default='[email protected]')
126130
url = serializers.URLField(default='http://www.example.com', allow_null=True)
127131
uuid = serializers.UUIDField()

0 commit comments

Comments
 (0)