Skip to content

Commit 329c31a

Browse files
committed
refine q < 10**context.prec checks
1 parent ee617bd commit 329c31a

File tree

1 file changed

+62
-8
lines changed

1 file changed

+62
-8
lines changed

Lib/_pydecimal.py

Lines changed: 62 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -92,8 +92,6 @@
9292

9393
MIN_ETINY = MIN_EMIN - (MAX_PREC-1)
9494

95-
_LOG_10_BASE_2 = float.fromhex('0x1.a934f0979a371p+1') # log2(10)
96-
9795
# Errors
9896

9997
class DecimalException(ArithmeticError):
@@ -444,6 +442,27 @@ def IEEEContext(bits, /):
444442

445443
##### Decimal class #######################################################
446444

445+
# Observation: For all q >= 0 and a >= 1, q < 10**a iff len(str(q)) <= a.
446+
#
447+
# The constants below are used to speed-up "q < 10 ** a" checks to avoid
448+
# computing len(str(q)) as much as possible. Those speed-ups are based on
449+
# the following claims.
450+
#
451+
# See https://github.com/python/cpython/issues/140036 for details.
452+
453+
# Claim: If 0 < z <= log2(10) and q.bit_length() < a*z, then q < 10**a.
454+
# Proof: By contradiction, q >= 10**a. By definition,
455+
# log2(q) >= a*log2(10) >= a*z > q.bit_length().
456+
# In particular, q > 2**q.bit_length(), which is impossible.
457+
_LOG_10_BASE_2_LO = float.fromhex('0x1.a934f0979a371p+1')
458+
assert pow(2, _LOG_10_BASE_2_LO) < 10
459+
460+
# Claim: If z > log2(10) and q.bit_length() >= 1 + a*z, then q > 10**a.
461+
# Proof: Since q >= 2**(q.bit_length()-1), we have
462+
# q >= 2**(q.bit_length()-1) >= 2**(a*z) > 2**(a*log2(10)) = 10**a.
463+
_LOG_10_BASE_2_HI = float.fromhex('0x1.a934f0979a372p+1')
464+
assert pow(2, _LOG_10_BASE_2_HI) > 10
465+
447466
# Do not subclass Decimal from numbers.Real and do not register it as such
448467
# (because Decimals are not interoperable with floats). See the notes in
449468
# numbers.py for more detail.
@@ -1357,11 +1376,27 @@ def _divide(self, other, context):
13571376
else:
13581377
op2.int *= 10**(op2.exp - op1.exp)
13591378
q, r = divmod(op1.int, op2.int)
1360-
if q.bit_length() < 1 + context.prec * _LOG_10_BASE_2:
1361-
# ensure that the previous check was sufficient
1362-
if len(str_q := str(q)) <= context.prec:
1363-
return (_dec_from_triple(sign, str_q, 0),
1364-
_dec_from_triple(self._sign, str(r), ideal_exp))
1379+
# See notes for _LOG_10_BASE_2_LO and _LOG_10_BASE_2_HI.
1380+
str_q = None # to cache str(q) when possible
1381+
if q.bit_length() < context.prec * _LOG_10_BASE_2_LO:
1382+
# assert q < 10 ** context.prec
1383+
is_valid = True
1384+
elif q.bit_length() >= 1 + context.prec * _LOG_10_BASE_2_HI:
1385+
# assert q > 10 ** context.prec
1386+
is_valid = False
1387+
else:
1388+
# Handles other cases due to floating point precision loss
1389+
# when computing _LOG_10_BASE_2_LO and _LOG_10_BASE_2_HI.
1390+
# Computation of str(q) may fail!
1391+
str_q = str(q) # we need to compute this in case of success
1392+
is_valid = len(str_q) <= context.prec
1393+
if is_valid:
1394+
if str_q is None:
1395+
str_q = str(q)
1396+
# assert q < 10 ** context.prec
1397+
# assert len(str(q)) <= context.prec
1398+
return (_dec_from_triple(sign, str_q, 0),
1399+
_dec_from_triple(self._sign, str(r), ideal_exp))
13651400

13661401
# Here the quotient is too large to be representable
13671402
ans = context._raise_error(DivisionImpossible,
@@ -1515,7 +1550,26 @@ def remainder_near(self, other, context=None):
15151550
r -= op2.int
15161551
q += 1
15171552

1518-
if q >= 10**context.prec:
1553+
# See notes for _LOG_10_BASE_2_LO and _LOG_10_BASE_2_HI.
1554+
if q.bit_length() < context.prec * _LOG_10_BASE_2_LO:
1555+
# assert q < 10 ** context.prec
1556+
is_valid = True
1557+
elif q.bit_length() >= 1 + context.prec * _LOG_10_BASE_2_HI:
1558+
# assert q > 10 ** context.prec
1559+
is_valid = False
1560+
else:
1561+
# Handles other cases due to floating point precision loss
1562+
# when computing _LOG_10_BASE_2_LO and _LOG_10_BASE_2_HI.
1563+
# Computation of str(q) or 10 ** context.prec may be slow!
1564+
try:
1565+
str_q = str(q)
1566+
except ValueError:
1567+
is_valid = q < 10 ** context.prec
1568+
else:
1569+
is_valid = len(str_q) <= context.prec
1570+
if not is_valid:
1571+
# assert q >= 10 ** context.prec
1572+
# assert len(str(q)) > context.prec
15191573
return context._raise_error(DivisionImpossible)
15201574

15211575
# result has same sign as self unless r is negative

0 commit comments

Comments
 (0)