Skip to content

Commit 0a5a115

Browse files
committed
optimize list.__eq__ with a jit driver
fixes pypygh-5300
1 parent 1027296 commit 0a5a115

File tree

3 files changed

+112
-16
lines changed

3 files changed

+112
-16
lines changed

pypy/module/pypyjit/test_pypy_c/test_containers.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -379,3 +379,28 @@ def main():
379379
opnames = log.opnames(ops)
380380
# only a call to ll_kvi, not to get_strategy_from_list_objects
381381
assert opnames.count("call_r") == 1
382+
383+
def test_list_eq(self):
384+
def main():
385+
l1 = [i * 3 + 1 for i in range(10000)]
386+
l2 = [i * 3 + 1 for i in range(10000)]
387+
return l1 == l2
388+
log = self.run(main, [])
389+
loop = log._filter(log.loops[-1], is_entry_bridge=False)
390+
loop.match("""
391+
i22 = uint_ge(i19, i6)
392+
guard_false(i22, descr=...)
393+
i23 = getarrayitem_gc_i(p8, i19, descr=...)
394+
i24 = uint_ge(i19, i13)
395+
guard_false(i24, descr=...)
396+
i25 = getarrayitem_gc_i(p15, i19, descr=...)
397+
i26 = int_eq(i25, i23)
398+
guard_true(i26, descr=...)
399+
i28 = int_add(i19, 1)
400+
i29 = int_lt(i28, i6)
401+
guard_true(i29, descr=...)
402+
i30 = int_lt(i28, i13)
403+
guard_true(i30, descr=...)
404+
i31 = arraylen_gc(p8, descr=...)
405+
i32 = arraylen_gc(p15, descr=...)
406+
jump(i28, p1, p2, p5, i6, p8, p12, i13, p15, descr=...)""")

pypy/objspace/std/listobject.py

Lines changed: 63 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,68 @@ def list_unroll_condition(w_list1, space, w_list2):
226226
return (w_list1._unrolling_heuristic() or w_list2._unrolling_heuristic())
227227

228228

229+
def _get_printable_location(strategy_type1, strategy_type2, typ):
230+
return 'list.eq [%s, %s, %s]' % (
231+
strategy_type1,
232+
strategy_type2,
233+
typ)
234+
235+
236+
listeq_jitdriver = jit.JitDriver(
237+
name='list.eq',
238+
greens=['strategy_type1', 'strategy_type2', 'typ'],
239+
reds='auto',
240+
get_printable_location=_get_printable_location)
241+
242+
def _make_list_eq(withjitdriver):
243+
def list_eq(w_list, space, w_other):
244+
# needs to be safe against eq_w() mutating the w_lists behind our back
245+
length = w_list.length()
246+
if length != w_other.length():
247+
return space.w_False
248+
if not length:
249+
return space.w_True
250+
if withjitdriver:
251+
typ = type(w_list.getitem(0))
252+
253+
i = 0
254+
while True:
255+
if withjitdriver:
256+
listeq_jitdriver.jit_merge_point(
257+
strategy_type1=type(w_list.strategy),
258+
strategy_type2=type(w_other.strategy),
259+
typ=typ,
260+
)
261+
try:
262+
w_item1 = w_list.getitem(i)
263+
w_item2 = w_other.getitem(i)
264+
except IndexError:
265+
break
266+
if not space.eq_w(w_item1, w_item2):
267+
return space.w_False
268+
269+
i += 1
270+
271+
# if the list length is different now, the list was modified by eq_w with
272+
l1 = w_list.length()
273+
l2 = w_other.length()
274+
if l1 != l2:
275+
return space.w_False
276+
return space.w_True
277+
return list_eq
278+
_list_eq_withjitdriver = _make_list_eq(True)
279+
_list_eq_unroll = jit.unroll_safe(_make_list_eq(False))
280+
281+
282+
def list_eq(w_list, space, w_other):
283+
# we can't use look_inside_iff because the jitdriver will be found in two
284+
# different graphs then
285+
if jit.we_are_jitted() and list_unroll_condition(w_list, space, w_other):
286+
return _list_eq_unroll(w_list, space, w_other)
287+
else:
288+
return _list_eq_withjitdriver(w_list, space, w_other)
289+
290+
229291
class W_ListObject(W_Root):
230292
strategy = None
231293

@@ -507,22 +569,7 @@ def descr_repr(self, space):
507569
def descr_eq(self, space, w_other):
508570
if not isinstance(w_other, W_ListObject):
509571
return space.w_NotImplemented
510-
return self._descr_eq(space, w_other)
511-
512-
@jit.look_inside_iff(list_unroll_condition)
513-
def _descr_eq(self, space, w_other):
514-
# needs to be safe against eq_w() mutating the w_lists behind our back
515-
if self.length() != w_other.length():
516-
return space.w_False
517-
518-
# XXX in theory, this can be implemented more efficiently as well.
519-
# let's not care for now
520-
i = 0
521-
while i < self.length() and i < w_other.length():
522-
if not space.eq_w(self.getitem(i), w_other.getitem(i)):
523-
return space.w_False
524-
i += 1
525-
return space.w_True
572+
return list_eq(self, space, w_other)
526573

527574
descr_ne = negate(descr_eq)
528575

pypy/objspace/std/test/test_listobject.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1544,6 +1544,30 @@ def __repr__(self):
15441544
l.append(191)
15451545
assert repr(l) == '[1, ouchie]'
15461546

1547+
def test_mutate_while_eq(self):
1548+
class Mean(object):
1549+
def __init__(self, i):
1550+
self.i = i
1551+
def __eq__(self, other):
1552+
if self.i == 9:
1553+
del l1[self.i - 1]
1554+
return True
1555+
l1 = [Mean(i) for i in range(10)]
1556+
l2 = [Mean(i) for i in range(10)]
1557+
assert l1 != l2
1558+
1559+
class Mean(object):
1560+
def __init__(self, i):
1561+
self.i = i
1562+
def __eq__(self, other):
1563+
if self.i == 9:
1564+
del l1[self.i - 1]
1565+
del l2[self.i - 1]
1566+
return True
1567+
l1 = [Mean(i) for i in range(10)]
1568+
l2 = [Mean(i) for i in range(10)]
1569+
assert l1 == l2
1570+
15471571
def test___getslice__(self):
15481572
l = [1,2,3,4]
15491573
res = l.__getslice__(0, 2)

0 commit comments

Comments
 (0)