Skip to content

Commit 7f60928

Browse files
committed
CVE-2022-48560 Cherry-pick c563f40
Cherry-pick c563f40 Fix up conflicts
1 parent c001067 commit 7f60928

File tree

3 files changed

+59
-7
lines changed

3 files changed

+59
-7
lines changed

Lib/test/test_heapq.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -396,6 +396,37 @@ def test_heappop_mutating_heap(self):
396396
with self.assertRaises((IndexError, RuntimeError)):
397397
self.module.heappop(heap)
398398

399+
def test_comparison_operator_modifiying_heap(self):
400+
# See bpo-39421: Strong references need to be taken
401+
# when comparing objects as they can alter the heap
402+
class EvilClass(int):
403+
def __lt__(self, o):
404+
heap.clear()
405+
return NotImplemented
406+
407+
heap = []
408+
self.module.heappush(heap, EvilClass(0))
409+
self.assertRaises(IndexError, self.module.heappushpop, heap, 1)
410+
411+
def test_comparison_operator_modifiying_heap_two_heaps(self):
412+
413+
class h(int):
414+
def __lt__(self, o):
415+
list2.clear()
416+
return NotImplemented
417+
418+
class g(int):
419+
def __lt__(self, o):
420+
list1.clear()
421+
return NotImplemented
422+
423+
list1, list2 = [], []
424+
425+
self.module.heappush(list1, h(0))
426+
self.module.heappush(list2, g(0))
427+
428+
self.assertRaises((IndexError, RuntimeError), self.module.heappush, list1, g(1))
429+
self.assertRaises((IndexError, RuntimeError), self.module.heappush, list2, h(1))
399430

400431
class TestErrorHandlingPython(TestErrorHandling):
401432
module = py_heapq
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Fix possible crashes when operating with the functions in the :mod:`heapq`
2+
module and custom comparison operators.

Modules/_heapqmodule.c

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,11 @@ _siftdown(PyListObject *heap, Py_ssize_t startpos, Py_ssize_t pos)
5252
while (pos > startpos) {
5353
parentpos = (pos - 1) >> 1;
5454
parent = PyList_GET_ITEM(heap, parentpos);
55+
Py_INCREF(newitem);
56+
Py_INCREF(parent);
5557
cmp = cmp_lt(newitem, parent);
58+
Py_DECREF(parent);
59+
Py_DECREF(newitem);
5660
if (cmp == -1)
5761
return -1;
5862
if (size != PyList_GET_SIZE(heap)) {
@@ -93,9 +97,13 @@ _siftup(PyListObject *heap, Py_ssize_t pos)
9397
childpos = 2*pos + 1; /* leftmost child position */
9498
rightpos = childpos + 1;
9599
if (rightpos < endpos) {
96-
cmp = cmp_lt(
97-
PyList_GET_ITEM(heap, childpos),
98-
PyList_GET_ITEM(heap, rightpos));
100+
PyObject* a = PyList_GET_ITEM(heap, childpos);
101+
PyObject* b = PyList_GET_ITEM(heap, rightpos);
102+
Py_INCREF(a);
103+
Py_INCREF(b);
104+
cmp = cmp_lt(a,b);
105+
Py_DECREF(a);
106+
Py_DECREF(b);
99107
if (cmp == -1)
100108
return -1;
101109
if (cmp == 0)
@@ -236,7 +244,10 @@ heappushpop(PyObject *self, PyObject *args)
236244
return item;
237245
}
238246

239-
cmp = cmp_lt(PyList_GET_ITEM(heap, 0), item);
247+
PyObject* top = PyList_GET_ITEM(heap, 0);
248+
Py_INCREF(top);
249+
cmp = cmp_lt(top, item);
250+
Py_DECREF(top);
240251
if (cmp == -1)
241252
return NULL;
242253
if (cmp == 0) {
@@ -395,7 +406,11 @@ _siftdownmax(PyListObject *heap, Py_ssize_t startpos, Py_ssize_t pos)
395406
while (pos > startpos){
396407
parentpos = (pos - 1) >> 1;
397408
parent = PyList_GET_ITEM(heap, parentpos);
409+
Py_INCREF(parent);
410+
Py_INCREF(newitem);
398411
cmp = cmp_lt(parent, newitem);
412+
Py_DECREF(parent);
413+
Py_DECREF(newitem);
399414
if (cmp == -1) {
400415
Py_DECREF(newitem);
401416
return -1;
@@ -436,9 +451,13 @@ _siftupmax(PyListObject *heap, Py_ssize_t pos)
436451
childpos = 2*pos + 1; /* leftmost child position */
437452
rightpos = childpos + 1;
438453
if (rightpos < endpos) {
439-
cmp = cmp_lt(
440-
PyList_GET_ITEM(heap, rightpos),
441-
PyList_GET_ITEM(heap, childpos));
454+
PyObject* a = PyList_GET_ITEM(heap, rightpos);
455+
PyObject* b = PyList_GET_ITEM(heap, childpos);
456+
Py_INCREF(a);
457+
Py_INCREF(b);
458+
cmp = cmp_lt(a, b);
459+
Py_DECREF(a);
460+
Py_DECREF(b);
442461
if (cmp == -1) {
443462
Py_DECREF(newitem);
444463
return -1;

0 commit comments

Comments
 (0)