Skip to content

Commit 64790e7

Browse files
authored
Merge pull request #42 from ActiveState/BE-3657-cve-2022-48560
Be 3657 CVE 2022 48560
2 parents c001067 + 7b72978 commit 64790e7

File tree

3 files changed

+62
-7
lines changed

3 files changed

+62
-7
lines changed

Lib/test/test_heapq.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -396,6 +396,40 @@ def test_heappop_mutating_heap(self):
396396
with self.assertRaises((IndexError, RuntimeError)):
397397
self.module.heappop(heap)
398398

399+
def test_comparison_operator_modifiying_heap(self):
400+
# See bpo-39421: Strong references need to be taken
401+
# when comparing objects as they can alter the heap
402+
class EvilClass(int):
403+
def __lt__(self, o):
404+
# heap.clear()
405+
del heap[:]
406+
return NotImplemented
407+
408+
heap = []
409+
self.module.heappush(heap, EvilClass(0))
410+
self.assertRaises(IndexError, self.module.heappushpop, heap, 1)
411+
412+
def test_comparison_operator_modifiying_heap_two_heaps(self):
413+
414+
class h(int):
415+
def __lt__(self, o):
416+
# list2.clear()
417+
del list2[:]
418+
return NotImplemented
419+
420+
class g(int):
421+
def __lt__(self, o):
422+
# list1.clear()
423+
del list1[:]
424+
return NotImplemented
425+
426+
list1, list2 = [], []
427+
428+
self.module.heappush(list1, h(0))
429+
self.module.heappush(list2, g(0))
430+
431+
self.assertRaises((IndexError, RuntimeError), self.module.heappush, list1, g(1))
432+
self.assertRaises((IndexError, RuntimeError), self.module.heappush, list2, h(1))
399433

400434
class TestErrorHandlingPython(TestErrorHandling):
401435
module = py_heapq
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Fix possible crashes when operating with the functions in the :mod:`heapq`
2+
module and custom comparison operators.

Modules/_heapqmodule.c

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,11 @@ _siftdown(PyListObject *heap, Py_ssize_t startpos, Py_ssize_t pos)
5252
while (pos > startpos) {
5353
parentpos = (pos - 1) >> 1;
5454
parent = PyList_GET_ITEM(heap, parentpos);
55+
Py_INCREF(newitem);
56+
Py_INCREF(parent);
5557
cmp = cmp_lt(newitem, parent);
58+
Py_DECREF(parent);
59+
Py_DECREF(newitem);
5660
if (cmp == -1)
5761
return -1;
5862
if (size != PyList_GET_SIZE(heap)) {
@@ -93,9 +97,13 @@ _siftup(PyListObject *heap, Py_ssize_t pos)
9397
childpos = 2*pos + 1; /* leftmost child position */
9498
rightpos = childpos + 1;
9599
if (rightpos < endpos) {
96-
cmp = cmp_lt(
97-
PyList_GET_ITEM(heap, childpos),
98-
PyList_GET_ITEM(heap, rightpos));
100+
PyObject* a = PyList_GET_ITEM(heap, childpos);
101+
PyObject* b = PyList_GET_ITEM(heap, rightpos);
102+
Py_INCREF(a);
103+
Py_INCREF(b);
104+
cmp = cmp_lt(a,b);
105+
Py_DECREF(a);
106+
Py_DECREF(b);
99107
if (cmp == -1)
100108
return -1;
101109
if (cmp == 0)
@@ -236,7 +244,10 @@ heappushpop(PyObject *self, PyObject *args)
236244
return item;
237245
}
238246

239-
cmp = cmp_lt(PyList_GET_ITEM(heap, 0), item);
247+
PyObject* top = PyList_GET_ITEM(heap, 0);
248+
Py_INCREF(top);
249+
cmp = cmp_lt(top, item);
250+
Py_DECREF(top);
240251
if (cmp == -1)
241252
return NULL;
242253
if (cmp == 0) {
@@ -395,7 +406,11 @@ _siftdownmax(PyListObject *heap, Py_ssize_t startpos, Py_ssize_t pos)
395406
while (pos > startpos){
396407
parentpos = (pos - 1) >> 1;
397408
parent = PyList_GET_ITEM(heap, parentpos);
409+
Py_INCREF(parent);
410+
Py_INCREF(newitem);
398411
cmp = cmp_lt(parent, newitem);
412+
Py_DECREF(parent);
413+
Py_DECREF(newitem);
399414
if (cmp == -1) {
400415
Py_DECREF(newitem);
401416
return -1;
@@ -436,9 +451,13 @@ _siftupmax(PyListObject *heap, Py_ssize_t pos)
436451
childpos = 2*pos + 1; /* leftmost child position */
437452
rightpos = childpos + 1;
438453
if (rightpos < endpos) {
439-
cmp = cmp_lt(
440-
PyList_GET_ITEM(heap, rightpos),
441-
PyList_GET_ITEM(heap, childpos));
454+
PyObject* a = PyList_GET_ITEM(heap, rightpos);
455+
PyObject* b = PyList_GET_ITEM(heap, childpos);
456+
Py_INCREF(a);
457+
Py_INCREF(b);
458+
cmp = cmp_lt(a, b);
459+
Py_DECREF(a);
460+
Py_DECREF(b);
442461
if (cmp == -1) {
443462
Py_DECREF(newitem);
444463
return -1;

0 commit comments

Comments
 (0)