Skip to content

Commit ef3e813

Browse files
Allow validation against max_digits and decimals to pass if normalized or non-normalized input is valid (#1049)
1 parent dd75669 commit ef3e813

File tree

2 files changed

+104
-53
lines changed

2 files changed

+104
-53
lines changed

src/validators/decimal.rs

Lines changed: 76 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,41 @@ impl_py_gc_traverse!(DecimalValidator {
8383
gt
8484
});
8585

86+
fn extract_decimal_digits_info<'data>(
87+
decimal: &PyAny,
88+
normalized: bool,
89+
py: Python<'data>,
90+
) -> ValResult<'data, (u64, u64)> {
91+
let mut normalized_decimal: Option<&PyAny> = None;
92+
if normalized {
93+
normalized_decimal = Some(decimal.call_method0(intern!(py, "normalize")).unwrap_or(decimal));
94+
}
95+
let (_, digit_tuple, exponent): (&PyAny, &PyTuple, &PyAny) = normalized_decimal
96+
.unwrap_or(decimal)
97+
.call_method0(intern!(py, "as_tuple"))?
98+
.extract()?;
99+
100+
// finite values have numeric exponent, we checked is_finite above
101+
let exponent: i64 = exponent.extract()?;
102+
let mut digits: u64 = u64::try_from(digit_tuple.len()).map_err(|e| ValError::InternalErr(e.into()))?;
103+
let decimals;
104+
if exponent >= 0 {
105+
// A positive exponent adds that many trailing zeros.
106+
digits += exponent as u64;
107+
decimals = 0;
108+
} else {
109+
// If the absolute value of the negative exponent is larger than the
110+
// number of digits, then it's the same as the number of digits,
111+
// because it'll consume all the digits in digit_tuple and then
112+
// add abs(exponent) - len(digit_tuple) leading zeros after the
113+
// decimal point.
114+
decimals = exponent.unsigned_abs();
115+
digits = digits.max(decimals);
116+
}
117+
118+
Ok((decimals, digits))
119+
}
120+
86121
impl Validator for DecimalValidator {
87122
fn validate<'data>(
88123
&self,
@@ -98,65 +133,53 @@ impl Validator for DecimalValidator {
98133
}
99134

100135
if self.check_digits {
101-
let normalized_value = decimal.call_method0(intern!(py, "normalize")).unwrap_or(decimal);
102-
let (_, digit_tuple, exponent): (&PyAny, &PyTuple, &PyAny) =
103-
normalized_value.call_method0(intern!(py, "as_tuple"))?.extract()?;
136+
if let Ok((normalized_decimals, normalized_digits)) = extract_decimal_digits_info(decimal, true, py) {
137+
if let Ok((decimals, digits)) = extract_decimal_digits_info(decimal, false, py) {
138+
if let Some(max_digits) = self.max_digits {
139+
if (digits > max_digits) & (normalized_digits > max_digits) {
140+
return Err(ValError::new(
141+
ErrorType::DecimalMaxDigits {
142+
max_digits,
143+
context: None,
144+
},
145+
input,
146+
));
147+
}
148+
}
104149

105-
// finite values have numeric exponent, we checked is_finite above
106-
let exponent: i64 = exponent.extract()?;
107-
let mut digits: u64 = u64::try_from(digit_tuple.len()).map_err(|e| ValError::InternalErr(e.into()))?;
108-
let decimals;
109-
if exponent >= 0 {
110-
// A positive exponent adds that many trailing zeros.
111-
digits += exponent as u64;
112-
decimals = 0;
113-
} else {
114-
// If the absolute value of the negative exponent is larger than the
115-
// number of digits, then it's the same as the number of digits,
116-
// because it'll consume all the digits in digit_tuple and then
117-
// add abs(exponent) - len(digit_tuple) leading zeros after the
118-
// decimal point.
119-
decimals = exponent.unsigned_abs();
120-
digits = digits.max(decimals);
121-
}
150+
if let Some(decimal_places) = self.decimal_places {
151+
if (decimals > decimal_places) & (normalized_decimals > decimal_places) {
152+
return Err(ValError::new(
153+
ErrorType::DecimalMaxPlaces {
154+
decimal_places,
155+
context: None,
156+
},
157+
input,
158+
));
159+
}
122160

123-
if let Some(max_digits) = self.max_digits {
124-
if digits > max_digits {
125-
return Err(ValError::new(
126-
ErrorType::DecimalMaxDigits {
127-
max_digits,
128-
context: None,
129-
},
130-
input,
131-
));
132-
}
133-
}
161+
if let Some(max_digits) = self.max_digits {
162+
let whole_digits = digits.saturating_sub(decimals);
163+
let max_whole_digits = max_digits.saturating_sub(decimal_places);
134164

135-
if let Some(decimal_places) = self.decimal_places {
136-
if decimals > decimal_places {
137-
return Err(ValError::new(
138-
ErrorType::DecimalMaxPlaces {
139-
decimal_places,
140-
context: None,
141-
},
142-
input,
143-
));
144-
}
165+
let normalized_whole_digits = normalized_digits.saturating_sub(normalized_decimals);
166+
let normalized_max_whole_digits = max_digits.saturating_sub(decimal_places);
145167

146-
if let Some(max_digits) = self.max_digits {
147-
let whole_digits = digits.saturating_sub(decimals);
148-
let max_whole_digits = max_digits.saturating_sub(decimal_places);
149-
if whole_digits > max_whole_digits {
150-
return Err(ValError::new(
151-
ErrorType::DecimalWholeDigits {
152-
whole_digits: max_whole_digits,
153-
context: None,
154-
},
155-
input,
156-
));
168+
if (whole_digits > max_whole_digits)
169+
& (normalized_whole_digits > normalized_max_whole_digits)
170+
{
171+
return Err(ValError::new(
172+
ErrorType::DecimalWholeDigits {
173+
whole_digits: max_whole_digits,
174+
context: None,
175+
},
176+
input,
177+
));
178+
}
179+
}
157180
}
158181
}
159-
}
182+
};
160183
}
161184
}
162185

tests/validators/test_decimal.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -437,3 +437,31 @@ def test_non_finite_constrained_decimal_values(input_value, allow_inf_nan, expec
437437
def test_validate_scientific_notation_from_json(input_value, expected):
438438
v = SchemaValidator({'type': 'decimal'})
439439
assert v.validate_json(input_value) == expected
440+
441+
442+
def test_validate_max_digits_and_decimal_places() -> None:
443+
v = SchemaValidator({'type': 'decimal', 'max_digits': 5, 'decimal_places': 2})
444+
445+
# valid inputs
446+
assert v.validate_json('1.23') == Decimal('1.23')
447+
assert v.validate_json('123.45') == Decimal('123.45')
448+
assert v.validate_json('-123.45') == Decimal('-123.45')
449+
450+
# invalid inputs
451+
with pytest.raises(ValidationError):
452+
v.validate_json('1234.56') # too many digits
453+
with pytest.raises(ValidationError):
454+
v.validate_json('123.456') # too many decimal places
455+
with pytest.raises(ValidationError):
456+
v.validate_json('123456') # too many digits
457+
with pytest.raises(ValidationError):
458+
v.validate_json('abc') # not a valid decimal
459+
460+
461+
def test_validate_max_digits_and_decimal_places_edge_case() -> None:
462+
v = SchemaValidator({'type': 'decimal', 'max_digits': 34, 'decimal_places': 18})
463+
464+
# valid inputs
465+
assert v.validate_python(Decimal('9999999999999999.999999999999999999')) == Decimal(
466+
'9999999999999999.999999999999999999'
467+
)

0 commit comments

Comments
 (0)