Skip to content
99 changes: 96 additions & 3 deletions Lib/_pydecimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,92 @@ def IEEEContext(bits, /):

##### Decimal class #######################################################

# Observation: For all q >= 0 and a >= 1, q < 10**a iff len(str(q)) <= a.
#
# The constants below are used to speed-up "q < 10 ** a" checks to avoid
# computing len(str(q)) as much as possible. Those speed-ups are based on
# the following claims.
#
# See https://github.com/python/cpython/issues/140036 for details.

_LOG_10_BASE_2_LO = float.fromhex('0x1.a934f0979a371p+1')
assert pow(2, _LOG_10_BASE_2_LO) < 10

_LOG_10_BASE_2_HI = float.fromhex('0x1.a934f0979a372p+1')
assert pow(2, _LOG_10_BASE_2_HI) > 10


def _tento(n):
"""Compute 10 ** n with 1 base-5 exponentiation and 1 bit-shift."""
return (5 ** n) << n


def _is_leq_than_pow10a_use_str(q, a):
"""Try to efficiently check len(str(q)) <= a, or equivalently q < 10**a.

If it is not possible to efficiently compute len(str(q)),
this explicitly compute str(q) instead.

Return (len(str(q)) <= a, None) or (len(str(q)) <= a, str(q)).
"""
if q.bit_length() < a * _LOG_10_BASE_2_LO:
# Claim: If 0 < z <= log2(10) and q.bit_length() < a*z, then q < 10**a.
# Proof: By contradiction, q >= 10**a. By definition,
# log2(q) >= a*log2(10) >= a*z > q.bit_length().
# In particular, q > 2**q.bit_length(), which is impossible.

# assert q < 10 ** context.prec
return True, None
elif q.bit_length() >= 1 + a * _LOG_10_BASE_2_HI:
# Claim: If z > log2(10) and q.bit_length() >= 1 + a*z, then q > 10**a.
# Proof: Since q >= 2**(q.bit_length()-1), we have
# q >= 2**(q.bit_length()-1) >= 2**(a*z) > 2**(a*log2(10)) = 10**a.

# assert q > 10 ** context.prec
return False, None
# Handles other cases due to floating point precision loss
# when computing _LOG_10_BASE_2_LO and _LOG_10_BASE_2_HI.
str_q = str(q) # can raise a ValueError
is_valid = len(str_q) <= a
return is_valid, str_q


def _is_leq_than_pow10a(q, a, *, exact=True, ulp_order=20):
"""Check that len(str(q)) <= a without computing str(q).

When *exact* is false, computing len(str(q)) is replaced by f(q):

f(q) = floor(log10(q) + ulp(log10(q)) * ulp_order + 1.0)

Most of the time, f(q) = len(str(q)) but in some cases, it may
happen that f(q) > len(str(q)).

When *exact* is true, computing len(str(q)) requires one bigint
exponentiation that only depends on q.
"""

if q < 10:
return a >= 1

z = _math.log10(q)
t = _math.ulp(z) * ulp_order

if exact:
intlo = int(z - t)
inthi = int(z + t)
diff = inthi - intlo
assert diff in (0, 1)
if diff == 1:
lo = _tento(inthi) # may be slow
if q < lo:
inthi -= 1
assert q >= (lo // 10)
ndigits = inthi + 1
else:
ndigits = int(z + t + 1.0)
return ndigits <= a


# Do not subclass Decimal from numbers.Real and do not register it as such
# (because Decimals are not interoperable with floats). See the notes in
# numbers.py for more detail.
Expand Down Expand Up @@ -1355,8 +1441,13 @@ def _divide(self, other, context):
else:
op2.int *= 10**(op2.exp - op1.exp)
q, r = divmod(op1.int, op2.int)
if q < 10**context.prec:
return (_dec_from_triple(sign, str(q), 0),
is_valid, str_q = _is_leq_than_pow10a_use_str(q, context.prec)
if is_valid:
if str_q is None:
str_q = str(q)
# assert q < 10 ** context.prec
# assert len(str(q)) <= context.prec
return (_dec_from_triple(sign, str_q, 0),
_dec_from_triple(self._sign, str(r), ideal_exp))

# Here the quotient is too large to be representable
Expand Down Expand Up @@ -1511,7 +1602,9 @@ def remainder_near(self, other, context=None):
r -= op2.int
q += 1

if q >= 10**context.prec:
if not _is_leq_than_pow10a(q, context.prec):
# assert q >= 10 ** context.prec
# assert len(str(q)) > context.prec
return context._raise_error(DivisionImpossible)

# result has same sign as self unless r is negative
Expand Down
26 changes: 26 additions & 0 deletions Lib/test/test_decimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
with the corresponding argument.
"""

import contextlib
import logging
import math
import os, sys
Expand Down Expand Up @@ -4493,6 +4494,15 @@ def test_decimal_attributes(self):

class Coverage:

@contextlib.contextmanager
def unbound_context(self, prec=None, Emax=None, Emin=None):
with self.decimal.localcontext() as c:
c.prec = self.decimal.MAX_PREC if prec is None else prec
c.Emax = self.decimal.MAX_EMAX if Emax is None else Emax
c.Emin = self.decimal.MIN_EMIN if Emin is None else Emin
c.traps[self.decimal.Inexact] = 1
yield c

def test_adjusted(self):
Decimal = self.decimal.Decimal

Expand Down Expand Up @@ -4660,6 +4670,22 @@ def test_divmod(self):
self.assertTrue(c.flags[InvalidOperation] and
c.flags[DivisionByZero])

def test_divide_unbound_context(self):
with self.unbound_context() as c:
x = self.decimal.Decimal('1')
y = x // 1 # should be fast

def test_remainder_near(self):
L = 1000
limit = sys.get_int_max_str_digits()
sys.set_int_max_str_digits(L)
self.addCleanup(sys.set_int_max_str_digits, limit)

with self.unbound_context(prec=2 * L) as c:
self.assertEqual(c.prec, 2 * L)
x = self.decimal.Decimal(f'1e{L}')
y = x.remainder_near(1) # must not raise a ValueError

def test_power(self):
Decimal = self.decimal.Decimal
localcontext = self.decimal.localcontext
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Avoid hanging in floor division of pure Python :class:`decimal.Decimal`
instances when the context precision is very large. Patch by Bénédikt Tran.
Loading