Skip to content

Commit 2f0a338

Browse files
gh-118610: Centralize power caching in _pylong.py (#118611)
A new `compute_powers()` function computes all and only the powers of the base the various base-conversion functions need, as efficiently as reasonably possible (turns out that invoking `**` is needed at most once). This typically gives a few % speedup, but the primary point is to simplify the base-conversion functions, which no longer need their own, ad hoc, and less efficient power-caching schemes. Co-authored-by: Serhiy Storchaka <[email protected]>
1 parent 2a85bed commit 2f0a338

File tree

2 files changed

+113
-67
lines changed

2 files changed

+113
-67
lines changed

Lib/_pylong.py

Lines changed: 101 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,86 @@
1919
except ImportError:
2020
_decimal = None
2121

22+
# A number of functions have this form, where `w` is a desired number of
23+
# digits in base `base`:
24+
#
25+
# def inner(...w...):
26+
# if w <= LIMIT:
27+
# return something
28+
# lo = w >> 1
29+
# hi = w - lo
30+
# something involving base**lo, inner(...lo...), j, and inner(...hi...)
31+
# figure out largest w needed
32+
# result = inner(w)
33+
#
34+
# They all had some on-the-fly scheme to cache `base**lo` results for reuse.
35+
# Power is costly.
36+
#
37+
# This routine aims to compute all and only the needed powers in advance, as
38+
# efficiently as reasonably possible. This isn't trivial, and all the
39+
# on-the-fly methods did needless work in many cases. The driving code above
40+
# changes to:
41+
#
42+
# figure out largest w needed
43+
# mycache = compute_powers(w, base, LIMIT)
44+
# result = inner(w)
45+
#
46+
# and `mycache[lo]` replaces `base**lo` in the inner function.
47+
#
48+
# While this does give minor speedups (a few percent at best), the primary
49+
# intent is to simplify the functions using this, by eliminating the need for
50+
# them to craft their own ad-hoc caching schemes.
51+
def compute_powers(w, base, more_than, show=False):
    """Precompute the powers of `base` a halving recursion on `w` will need.

    Models a recursion that, for every width `w > more_than`, splits `w`
    into `lo = w >> 1` and `hi = w - lo`, uses `base**lo`, and recurses on
    both halves.  Widths `<= more_than` are treated as leaves and are not
    split further.

    Parameters:
        w: the largest width the recursion starts from.
        base: the object to raise to powers; it only has to support `**`
            with an int exponent and `*` with itself.
        more_than: the leaf threshold; only widths strictly greater than
            this are split.
        show: if true, print a trace of how each power is obtained.

    Returns a dict mapping each needed exponent `lo` to `base**lo`.  The
    dict is empty when `w <= more_than`.
    """
    seen = set()
    need = set()
    ws = {w}
    while ws:
        w = ws.pop()  # any element is fine to use next
        if w in seen or w <= more_than:
            continue
        seen.add(w)
        lo = w >> 1
        # only _need_ lo here; some other path may, or may not, need hi
        need.add(lo)
        ws.add(lo)
        if w & 1:
            # odd width: the larger half is lo + 1 and gets split as well
            ws.add(lo + 1)

    d = {}
    if not need:
        return d
    # Build the powers in increasing order so each one can be derived from
    # a smaller exponent already in `d`; only the smallest needs `**`.
    it = iter(sorted(need))
    first = next(it)
    if show:
        print("pow at", first)
    d[first] = base ** first
    for this in it:
        if this - 1 in d:
            if show:
                print("* base at", this)
            d[this] = d[this - 1] * base  # cheap
        else:
            lo = this >> 1
            hi = this - lo
            assert lo in d
            if show:
                print("square at", this)
            # Multiplying a bigint by itself (same object!) is about twice
            # as fast in CPython.
            sq = d[lo] * d[lo]
            if hi != lo:
                assert hi == lo + 1
                if show:
                    print(" and * base")
                sq *= base
            d[this] = sq
    return d
96+
97+
# Module-level decimal context, built once and passed to
# decimal.localcontext() by int_to_decimal().  It starts as a copy of
# whatever context is current when this code runs, then has its precision
# and exponent range set to the extreme values the decimal module exposes.
# The Inexact trap is enabled so that any operation that would have to
# round raises instead of silently losing digits.
_unbounded_dec_context = decimal.getcontext().copy()
98+
_unbounded_dec_context.prec = decimal.MAX_PREC
99+
_unbounded_dec_context.Emax = decimal.MAX_EMAX
100+
_unbounded_dec_context.Emin = decimal.MIN_EMIN
101+
_unbounded_dec_context.traps[decimal.Inexact] = 1 # sanity check
22102

23103
def int_to_decimal(n):
24104
"""Asymptotically fast conversion of an 'int' to Decimal."""
@@ -33,57 +113,32 @@ def int_to_decimal(n):
33113
# "clever" recursive way. If we want a string representation, we
34114
# apply str to _that_.
35115

36-
D = decimal.Decimal
37-
D2 = D(2)
38-
39-
BITLIM = 128
40-
41-
mem = {}
42-
43-
def w2pow(w):
44-
"""Return D(2)**w and store the result. Also possibly save some
45-
intermediate results. In context, these are likely to be reused
46-
across various levels of the conversion to Decimal."""
47-
if (result := mem.get(w)) is None:
48-
if w <= BITLIM:
49-
result = D2**w
50-
elif w - 1 in mem:
51-
result = (t := mem[w - 1]) + t
52-
else:
53-
w2 = w >> 1
54-
# If w happens to be odd, w-w2 is one larger then w2
55-
# now. Recurse on the smaller first (w2), so that it's
56-
# in the cache and the larger (w-w2) can be handled by
57-
# the cheaper `w-1 in mem` branch instead.
58-
result = w2pow(w2) * w2pow(w - w2)
59-
mem[w] = result
60-
return result
116+
from decimal import Decimal as D
117+
BITLIM = 200
61118

119+
# Don't bother caching the "lo" mask in this; the time to compute it is
120+
# tiny compared to the multiply.
62121
def inner(n, w):
63122
if w <= BITLIM:
64123
return D(n)
65124
w2 = w >> 1
66125
hi = n >> w2
67-
lo = n - (hi << w2)
68-
return inner(lo, w2) + inner(hi, w - w2) * w2pow(w2)
69-
70-
with decimal.localcontext() as ctx:
71-
ctx.prec = decimal.MAX_PREC
72-
ctx.Emax = decimal.MAX_EMAX
73-
ctx.Emin = decimal.MIN_EMIN
74-
ctx.traps[decimal.Inexact] = 1
126+
lo = n & ((1 << w2) - 1)
127+
return inner(lo, w2) + inner(hi, w - w2) * w2pow[w2]
75128

129+
with decimal.localcontext(_unbounded_dec_context):
130+
nbits = n.bit_length()
131+
w2pow = compute_powers(nbits, D(2), BITLIM)
76132
if n < 0:
77133
negate = True
78134
n = -n
79135
else:
80136
negate = False
81-
result = inner(n, n.bit_length())
137+
result = inner(n, nbits)
82138
if negate:
83139
result = -result
84140
return result
85141

86-
87142
def int_to_decimal_string(n):
88143
"""Asymptotically fast conversion of an 'int' to a decimal string."""
89144
w = n.bit_length()
@@ -97,14 +152,13 @@ def int_to_decimal_string(n):
97152
# available. This algorithm is asymptotically worse than the algorithm
98153
# using the decimal module, but better than the quadratic time
99154
# implementation in longobject.c.
155+
156+
DIGLIM = 1000
100157
def inner(n, w):
101-
if w <= 1000:
158+
if w <= DIGLIM:
102159
return str(n)
103160
w2 = w >> 1
104-
d = pow10_cache.get(w2)
105-
if d is None:
106-
d = pow10_cache[w2] = 5**w2 << w2 # 10**i = (5*2)**i = 5**i * 2**i
107-
hi, lo = divmod(n, d)
161+
hi, lo = divmod(n, pow10[w2])
108162
return inner(hi, w - w2) + inner(lo, w2).zfill(w2)
109163

110164
# The estimation of the number of decimal digits.
@@ -115,7 +169,9 @@ def inner(n, w):
115169
# only if the number has way more than 10**15 digits, that exceeds
116170
# the 52-bit physical address limit in both Intel64 and AMD64.
117171
w = int(w * 0.3010299956639812 + 1) # log10(2)
118-
pow10_cache = {}
172+
pow10 = compute_powers(w, 5, DIGLIM)
173+
for k, v in pow10.items():
174+
pow10[k] = v << k # 5**k << k == 5**k * 2**k == 10**k
119175
if n < 0:
120176
n = -n
121177
sign = '-'
@@ -128,7 +184,6 @@ def inner(n, w):
128184
s = s.lstrip('0')
129185
return sign + s
130186

131-
132187
def _str_to_int_inner(s):
133188
"""Asymptotically fast conversion of a 'str' to an 'int'."""
134189

@@ -144,35 +199,15 @@ def _str_to_int_inner(s):
144199

145200
DIGLIM = 2048
146201

147-
mem = {}
148-
149-
def w5pow(w):
150-
"""Return 5**w and store the result.
151-
Also possibly save some intermediate results. In context, these
152-
are likely to be reused across various levels of the conversion
153-
to 'int'.
154-
"""
155-
if (result := mem.get(w)) is None:
156-
if w <= DIGLIM:
157-
result = 5**w
158-
elif w - 1 in mem:
159-
result = mem[w - 1] * 5
160-
else:
161-
w2 = w >> 1
162-
# If w happens to be odd, w-w2 is one larger then w2
163-
# now. Recurse on the smaller first (w2), so that it's
164-
# in the cache and the larger (w-w2) can be handled by
165-
# the cheaper `w-1 in mem` branch instead.
166-
result = w5pow(w2) * w5pow(w - w2)
167-
mem[w] = result
168-
return result
169-
170202
def inner(a, b):
171203
if b - a <= DIGLIM:
172204
return int(s[a:b])
173205
mid = (a + b + 1) >> 1
174-
return inner(mid, b) + ((inner(a, mid) * w5pow(b - mid)) << (b - mid))
206+
return (inner(mid, b)
207+
+ ((inner(a, mid) * w5pow[b - mid])
208+
<< (b - mid)))
175209

210+
w5pow = compute_powers(len(s), 5, DIGLIM)
176211
return inner(0, len(s))
177212

178213

@@ -186,7 +221,6 @@ def int_from_string(s):
186221
s = s.rstrip().replace('_', '')
187222
return _str_to_int_inner(s)
188223

189-
190224
def str_to_int(s):
191225
"""Asymptotically fast version of decimal string to 'int' conversion."""
192226
# FIXME: this doesn't support the full syntax that int() supports.

Lib/test/test_int.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -906,6 +906,18 @@ def test_pylong_misbehavior_error_path_from_str(
906906
with self.assertRaises(RuntimeError):
907907
int(big_value)
908908

909+
def test_pylong_roundtrip(self):
    # Round-trip random ints through str() and int() and check that the
    # value survives.  Sizes start near 5,000 bits and roughly double each
    # pass until they exceed 1,000,000 bits.
    # NOTE(review): presumably these sizes are meant to reach the _pylong
    # conversion paths -- confirm against the enclosing test class.
    from random import randrange, getrandbits
    bits = 5000
    while bits <= 1_000_000:
        bits += randrange(-100, 101)  # break bitlength patterns
        hibit = 1 << (bits - 1)
        n = hibit | getrandbits(bits - 1)
        # Check the construction through unittest rather than a bare
        # `assert`, which is compiled away when Python runs with -O.
        self.assertEqual(n.bit_length(), bits)
        sn = str(n)
        self.assertFalse(sn.startswith('0'))
        self.assertEqual(n, int(sn))
        bits <<= 1
909921

910922
if __name__ == "__main__":
911923
unittest.main()

0 commit comments

Comments
 (0)