Skip to content

Commit 331c313

Browse files
Add rounding parameter to DecimalField (#5562)
* Adding rounding parameter to DecimalField. * Using standard `assert` instead of `self.fail()`. * add testcase and PEP8 multilines fix * flake8 fixes * Use decimal module constants in tests. * Add docs note for `rounding` parameter.
1 parent 565c722 commit 331c313

File tree

3 files changed

+25
-3
lines changed

3 files changed

+25
-3
lines changed

docs/api-guide/fields.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,7 @@ Corresponds to `django.db.models.fields.DecimalField`.
269269
- `max_value` Validate that the number provided is no greater than this value.
270270
- `min_value` Validate that the number provided is no less than this value.
271271
- `localize` Set to `True` to enable localization of input and output based on the current locale. This will also force `coerce_to_string` to `True`. Defaults to `False`. Note that data formatting is enabled if you have set `USE_L10N=True` in your settings file.
272+
- `rounding` Sets the rounding mode used when quantising to the configured precision. Valid values are [`decimal` module rounding modes][python-decimal-rounding-modes]. Defaults to `None`.
272273

273274
#### Example usage
274275

@@ -680,3 +681,4 @@ The [django-rest-framework-hstore][django-rest-framework-hstore] package provide
680681
[django-rest-framework-gis]: https://github.com/djangonauts/django-rest-framework-gis
681682
[django-rest-framework-hstore]: https://github.com/djangonauts/django-rest-framework-hstore
682683
[django-hstore]: https://github.com/djangonauts/django-hstore
684+
[python-decimal-rounding-modes]: https://docs.python.org/3/library/decimal.html#rounding-modes

rest_framework/fields.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -997,7 +997,7 @@ class DecimalField(Field):
997997
MAX_STRING_LENGTH = 1000 # Guard against malicious string inputs.
998998

999999
def __init__(self, max_digits, decimal_places, coerce_to_string=None, max_value=None, min_value=None,
1000-
localize=False, **kwargs):
1000+
localize=False, rounding=None, **kwargs):
10011001
self.max_digits = max_digits
10021002
self.decimal_places = decimal_places
10031003
self.localize = localize
@@ -1029,6 +1029,12 @@ def __init__(self, max_digits, decimal_places, coerce_to_string=None, max_value=
10291029
self.validators.append(
10301030
MinValueValidator(self.min_value, message=message))
10311031

1032+
if rounding is not None:
1033+
valid_roundings = [v for k, v in vars(decimal).items() if k.startswith('ROUND_')]
1034+
assert rounding in valid_roundings, (
1035+
'Invalid rounding option %s. Valid values for rounding are: %s' % (rounding, valid_roundings))
1036+
self.rounding = rounding
1037+
10321038
def to_internal_value(self, data):
10331039
"""
10341040
Validate that the input is a decimal number and return a Decimal
@@ -1121,6 +1127,7 @@ def quantize(self, value):
11211127
context.prec = self.max_digits
11221128
return value.quantize(
11231129
decimal.Decimal('.1') ** self.decimal_places,
1130+
rounding=self.rounding,
11241131
context=context
11251132
)
11261133

tests/test_fields.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import re
44
import unittest
55
import uuid
6-
from decimal import Decimal
6+
from decimal import ROUND_DOWN, ROUND_UP, Decimal
77

88
import django
99
import pytest
@@ -1092,8 +1092,21 @@ class TestNoDecimalPlaces(FieldValues):
10921092
field = serializers.DecimalField(max_digits=6, decimal_places=None)
10931093

10941094

1095-
# Date & time serializers...
1095+
class TestRoundingDecimalField(TestCase):
1096+
def test_valid_rounding(self):
1097+
field = serializers.DecimalField(max_digits=4, decimal_places=2, rounding=ROUND_UP)
1098+
assert field.to_representation(Decimal('1.234')) == '1.24'
1099+
1100+
field = serializers.DecimalField(max_digits=4, decimal_places=2, rounding=ROUND_DOWN)
1101+
assert field.to_representation(Decimal('1.234')) == '1.23'
1102+
1103+
def test_invalid_rounding(self):
1104+
with pytest.raises(AssertionError) as excinfo:
1105+
serializers.DecimalField(max_digits=1, decimal_places=1, rounding='ROUND_UNKNOWN')
1106+
assert 'Invalid rounding option' in str(excinfo.value)
10961107

1108+
1109+
# Date & time serializers...
10971110
class TestDateField(FieldValues):
10981111
"""
10991112
Valid and invalid values for `DateField`.

0 commit comments

Comments
 (0)