diff --git a/Lib/_pydecimal.py b/Lib/_pydecimal.py index 9b8e42a2342536..9fdecf68e89c5b 100644 --- a/Lib/_pydecimal.py +++ b/Lib/_pydecimal.py @@ -442,6 +442,95 @@ def IEEEContext(bits, /): ##### Decimal class ####################################################### +# Observation: For all q >= 0 and a >= 1, q < 10**a iff len(str(q)) <= a. +# +# The constants below are used to speed-up "q < 10 ** a" checks to avoid +# computing len(str(q)) as much as possible. Those speed-ups are based on +# the following claims. +# +# See https://github.com/python/cpython/issues/140036 for details. + +_LOG_10_BASE_2_LO = float.fromhex('0x1.a934f0979a371p+1') +assert pow(2, _LOG_10_BASE_2_LO) < 10 + +_LOG_10_BASE_2_HI = float.fromhex('0x1.a934f0979a372p+1') +assert pow(2, _LOG_10_BASE_2_HI) > 10 + + +def _tento(n): + """Compute 10 ** n with 1 base-5 exponentiation and 1 bit-shift.""" + return (5 ** n) << n + + +def _is_less_than_pow10a_use_str(q, a): + """Try to efficiently check len(str(q)) <= a, or equivalently q < 10**a. + + If the comparison cannot be obtained from q.bit_length(), + then str(q) is explicitly computed and may raise ValueError. + + Return (len(str(q)) <= a, None) or (len(str(q)) <= a, str(q)). + """ + if q.bit_length() < a * _LOG_10_BASE_2_LO: + # Claim: If 0 < z <= log2(10) and q.bit_length() < a*z, then q < 10**a. + # Proof: By contradiction, q >= 10**a. By definition, + # log2(q) >= a*log2(10) >= a*z > q.bit_length(). + # In particular, q > 2**q.bit_length(), which is impossible. + + # assert q < 10 ** a + return True, None + elif q.bit_length() >= 1 + a * _LOG_10_BASE_2_HI: + # Claim: If z > log2(10) and q.bit_length() >= 1 + a*z, then q > 10**a. + # Proof: Since q >= 2**(q.bit_length()-1), we have + # q >= 2**(q.bit_length()-1) >= 2**(a*z) > 2**(a*log2(10)) = 10**a. + + # assert q > 10 ** a + return False, None + # Handle cases that fail due to floating point precision loss + # when computing _LOG_10_BASE_2_LO and _LOG_10_BASE_2_HI, or + # that cannot be distinguished with (q.bit_length(), a) only. + # + # For instance, (q1, a) = (95, 2) and (q2, a) = (105, 2) produce + # different results but q1.bit_length() == q2.bit_length() == 7. + str_q = str(q) # can raise a ValueError + is_valid = len(str_q) <= a + return is_valid, str_q + + +def _is_less_than_pow10a(q, a, *, exact=True, ulp_order=20): + """Check that len(str(q)) <= a without computing str(q). + + When *exact* is false, computing len(str(q)) is replaced by f(q): + + f(q) = floor(log10(q) + ulp(log10(q)) * ulp_order + 1.0) + + Most of the time, f(q) = len(str(q)) but in some cases, it may + happen that f(q) > len(str(q)). + + When *exact* is true, computing len(str(q)) requires one bigint + exponentiation that only depends on q. + """ + if q < 10: + return a >= 1 + + z = _math.log10(q) + t = _math.ulp(z) * ulp_order + + if exact: + intlo = int(z - t) + inthi = int(z + t) + diff = inthi - intlo + assert diff in (0, 1) + if diff == 1: + lo = _tento(inthi) # may be slow + if q < lo: + inthi -= 1 + assert q >= (lo // 10) + ndigits = inthi + 1 + else: + ndigits = int(z + t + 1.0) + return ndigits <= a + + # Do not subclass Decimal from numbers.Real and do not register it as such # (because Decimals are not interoperable with floats). See the notes in # numbers.py for more detail. @@ -1355,8 +1444,13 @@ def _divide(self, other, context): else: op2.int *= 10**(op2.exp - op1.exp) q, r = divmod(op1.int, op2.int) - if q < 10**context.prec: - return (_dec_from_triple(sign, str(q), 0), + is_valid, str_q = _is_less_than_pow10a_use_str(q, context.prec) + if is_valid: + if str_q is None: + str_q = str(q) + # assert q < 10 ** context.prec + # assert len(str(q)) <= context.prec + return (_dec_from_triple(sign, str_q, 0), _dec_from_triple(self._sign, str(r), ideal_exp)) # Here the quotient is too large to be representable @@ -1511,7 +1605,9 @@ def remainder_near(self, other, context=None): r -= op2.int q += 1 - if q >= 10**context.prec: + if not _is_less_than_pow10a(q, context.prec): + # assert q >= 10 ** context.prec + # assert len(str(q)) > context.prec return context._raise_error(DivisionImpossible) # result has same sign as self unless r is negative diff --git a/Lib/test/test_decimal.py b/Lib/test/test_decimal.py index 08a8f4c3b36bd6..7eeb77dec78e0c 100644 --- a/Lib/test/test_decimal.py +++ b/Lib/test/test_decimal.py @@ -24,6 +24,7 @@ with the corresponding argument. """ +import contextlib import logging import math import os, sys @@ -2611,6 +2612,37 @@ def tearDown(self): sys.set_int_max_str_digits(self._previous_int_limit) super().tearDown() + def test_helper__is_less_than_pow10a_use_str_slow_path(self): + # Test the "slow" path of _is_less_than_pow10a_use_str(). + a, b = 2, 7 + + # Choose q1, q2 such that len(str(q1)) <= a < len(str(q2)) + # and q1.bit_length() == q2.bit_length() == b to check that + # we cover the "slow" path correctly even for small values. + q1, q2 = 95, 105 + b1, b2 = q1.bit_length(), q2.bit_length() + + self.assertEqual(b1, b) + self.assertEqual(b2, b) + + # ensure that the first "fast" check doesn't hold + self.assertGreaterEqual(b, a * self.decimal._LOG_10_BASE_2_LO) + # ensure that the second "fast" check doesn't hold + self.assertLess(b, 1 + a * self.decimal._LOG_10_BASE_2_HI) + + cond_q1, str_q1 = self.decimal._is_less_than_pow10a_use_str(q1, a) + self.assertTrue(cond_q1) + self.assertIsNotNone(str_q1) + + cond_q2, str_q2 = self.decimal._is_less_than_pow10a_use_str(q2, a) + self.assertFalse(cond_q2) + self.assertIsNotNone(str_q2) + + def test_helper__is_less_than_pow10a(self): + # TODO(picnixz): find a simple test case with custom ulp_order. + pass + + class PythonAPItests: def test_abc(self): @@ -4493,6 +4525,15 @@ def test_decimal_attributes(self): class Coverage: + @contextlib.contextmanager + def unbound_context(self, prec=None, Emax=None, Emin=None): + with self.decimal.localcontext() as c: + c.prec = self.decimal.MAX_PREC if prec is None else prec + c.Emax = self.decimal.MAX_EMAX if Emax is None else Emax + c.Emin = self.decimal.MIN_EMIN if Emin is None else Emin + c.traps[self.decimal.Inexact] = 1 + yield c + def test_adjusted(self): Decimal = self.decimal.Decimal @@ -4660,6 +4701,22 @@ def test_divmod(self): self.assertTrue(c.flags[InvalidOperation] and c.flags[DivisionByZero]) + def test_divide_unbound_context(self): + with self.unbound_context() as c: + x = self.decimal.Decimal('1') + y = x // 1 # should be fast + + def test_remainder_near(self): + L = 1000 + limit = sys.get_int_max_str_digits() + sys.set_int_max_str_digits(L) + self.addCleanup(sys.set_int_max_str_digits, limit) + + with self.unbound_context(prec=2 * L) as c: + self.assertEqual(c.prec, 2 * L) + x = self.decimal.Decimal(f'1e{L}') + y = x.remainder_near(1) # must not raise a ValueError + def test_power(self): Decimal = self.decimal.Decimal localcontext = self.decimal.localcontext diff --git a/Misc/NEWS.d/next/Library/2025-10-13-15-43-09.gh-issue-140036.b_59uN.rst b/Misc/NEWS.d/next/Library/2025-10-13-15-43-09.gh-issue-140036.b_59uN.rst new file mode 100644 index 00000000000000..4bd92b4d735225 --- /dev/null +++ b/Misc/NEWS.d/next/Library/2025-10-13-15-43-09.gh-issue-140036.b_59uN.rst @@ -0,0 +1,2 @@ +Avoid hanging in floor division of pure Python :class:`decimal.Decimal` +instances when the context precision is very large. Patch by Bénédikt Tran.