Skip to content

Commit 4951511

Browse files
committed
refine q < 10**context.prec checks
1 parent ee617bd commit 4951511

File tree

1 file changed

+66
-8
lines changed

1 file changed

+66
-8
lines changed

Lib/_pydecimal.py

Lines changed: 66 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -92,8 +92,6 @@
9292

9393
MIN_ETINY = MIN_EMIN - (MAX_PREC-1)
9494

95-
_LOG_10_BASE_2 = float.fromhex('0x1.a934f0979a371p+1') # log2(10)
96-
9795
# Errors
9896

9997
class DecimalException(ArithmeticError):
@@ -444,6 +442,27 @@ def IEEEContext(bits, /):
444442

445443
##### Decimal class #######################################################
446444

445+
# Observation: For all q >= 0 and a >= 1, q < 10**a iff len(str(q)) <= a.
446+
#
447+
# The constants below are used to speed-up "q < 10 ** a" checks to avoid
448+
# computing len(str(q)) as much as possible. Those speed-ups are based on
449+
# the following claims.
450+
#
451+
# See https://github.com/python/cpython/issues/140036 for details.
452+
453+
# Claim: If 0 < z <= log2(10) and q.bit_length() < a*z, then q < 10**a.
454+
# Proof: By contradiction, q >= 10**a. By definition,
455+
# log2(q) >= a*log2(10) >= a*z > q.bit_length().
456+
# In particular, q > 2**q.bit_length(), which is impossible.
457+
_LOG_10_BASE_2_LO = float.fromhex('0x1.a934f0979a371p+1')
458+
assert pow(2, _LOG_10_BASE_2_LO) < 10
459+
460+
# Claim: If z > log2(10) and q.bit_length() >= 1 + a*z, then q > 10**a.
461+
# Proof: Since q >= 2**(q.bit_length()-1), we have
462+
# q >= 2**(q.bit_length()-1) >= 2**(a*z) > 2**(a*log2(10)) = 10**a.
463+
_LOG_10_BASE_2_HI = float.fromhex('0x1.a934f0979a372p+1')
464+
assert pow(2, _LOG_10_BASE_2_HI) > 10
465+
447466
# Do not subclass Decimal from numbers.Real and do not register it as such
448467
# (because Decimals are not interoperable with floats). See the notes in
449468
# numbers.py for more detail.
@@ -1357,11 +1376,31 @@ def _divide(self, other, context):
13571376
else:
13581377
op2.int *= 10**(op2.exp - op1.exp)
13591378
q, r = divmod(op1.int, op2.int)
1360-
if q.bit_length() < 1 + context.prec * _LOG_10_BASE_2:
1361-
# ensure that the previous check was sufficient
1362-
if len(str_q := str(q)) <= context.prec:
1363-
return (_dec_from_triple(sign, str_q, 0),
1364-
_dec_from_triple(self._sign, str(r), ideal_exp))
1379+
# See notes for _LOG_10_BASE_2_LO and _LOG_10_BASE_2_HI.
1380+
str_q = None # to cache str(q) when possible
1381+
if q.bit_length() < context.prec * _LOG_10_BASE_2_LO:
1382+
# assert q < 10 ** context.prec
1383+
is_valid = True
1384+
elif q.bit_length() >= 1 + context.prec * _LOG_10_BASE_2_HI:
1385+
# assert q > 10 ** context.prec
1386+
is_valid = False
1387+
else:
1388+
# Handles other cases due to floating point precision loss
1389+
# when computing _LOG_10_BASE_2_LO and _LOG_10_BASE_2_HI.
1390+
# Computation of str(q) or 10 ** context.prec may be slow!
1391+
try:
1392+
str_q = str(q)
1393+
except ValueError:
1394+
is_valid = q < 10 ** context.prec
1395+
else:
1396+
is_valid = len(str_q) <= context.prec
1397+
if is_valid:
1398+
if str_q is None:
1399+
str_q = str(q)
1400+
# assert q < 10 ** context.prec
1401+
# assert len(str(q)) <= context.prec
1402+
return (_dec_from_triple(sign, str_q, 0),
1403+
_dec_from_triple(self._sign, str(r), ideal_exp))
13651404

13661405
# Here the quotient is too large to be representable
13671406
ans = context._raise_error(DivisionImpossible,
@@ -1515,7 +1554,26 @@ def remainder_near(self, other, context=None):
15151554
r -= op2.int
15161555
q += 1
15171556

1518-
if q >= 10**context.prec:
1557+
# See notes for _LOG_10_BASE_2_LO and _LOG_10_BASE_2_HI.
1558+
if q.bit_length() < context.prec * _LOG_10_BASE_2_LO:
1559+
# assert q < 10 ** context.prec
1560+
is_valid = True
1561+
elif q.bit_length() >= 1 + context.prec * _LOG_10_BASE_2_HI:
1562+
# assert q > 10 ** context.prec
1563+
is_valid = False
1564+
else:
1565+
# Handles other cases due to floating point precision loss
1566+
# when computing _LOG_10_BASE_2_LO and _LOG_10_BASE_2_HI.
1567+
# Computation of str(q) or 10 ** context.prec may be slow!
1568+
try:
1569+
str_q = str(q)
1570+
except ValueError:
1571+
is_valid = q < 10 ** context.prec
1572+
else:
1573+
is_valid = len(str_q) <= context.prec
1574+
if not is_valid:
1575+
# assert q >= 10 ** context.prec
1576+
# assert len(str(q)) > context.prec
15191577
return context._raise_error(DivisionImpossible)
15201578

15211579
# result has same sign as self unless r is negative

0 commit comments

Comments
 (0)