Skip to content

Commit e78a067

Browse files
committed
gh-92810: Avoid O(n^2) complexity in ABCMeta.__subclasscheck__
Signed-off-by: Martynov Maxim <[email protected]>
1 parent c432d01 commit e78a067

File tree

4 files changed

+162
-48
lines changed

4 files changed

+162
-48
lines changed

Lib/_py_abc.py

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from _weakrefset import WeakSet
22

3+
_UNSET = object()
4+
35

46
def get_cache_token():
57
"""Returns the current ABC cache token.
@@ -65,8 +67,23 @@ def register(cls, subclass):
6567
if issubclass(cls, subclass):
6668
# This would create a cycle, which is bad for the algorithm below
6769
raise RuntimeError("Refusing to create an inheritance cycle")
70+
71+
# Actual registration
6872
cls._abc_registry.add(subclass)
69-
ABCMeta._abc_invalidation_counter += 1 # Invalidate negative cache
73+
74+
# Recursively register the subclass in all ABC bases, to avoid recursive lookups.
75+
# >>> class Ancestor1(ABC): pass
76+
# >>> class Ancestor2(Ancestor1): pass
77+
# >>> class Other: pass
78+
# >>> Ancestor2.register(Other) # same result for Ancestor1.register(Other)
79+
# >>> issubclass(Other, Ancestor2) is True
80+
# >>> issubclass(Other, Ancestor1) is True
81+
for pcls in cls.__mro__:
82+
if hasattr(pcls, "_abc_registry"):
83+
pcls._abc_registry.add(subclass)
84+
85+
# Invalidate negative cache
86+
ABCMeta._abc_invalidation_counter += 1
7087
return subclass
7188

7289
def _dump_registry(cls, file=None):
@@ -137,11 +154,19 @@ def __subclasscheck__(cls, subclass):
137154
if issubclass(subclass, rcls):
138155
cls._abc_cache.add(subclass)
139156
return True
140-
# Check if it's a subclass of a subclass (recursive)
141-
for scls in cls.__subclasses__():
142-
if issubclass(subclass, scls):
143-
cls._abc_cache.add(subclass)
144-
return True
157+
158+
# Check if it's a subclass of a subclass (recursive).
159+
# >>> class Ancestor: __subclasses__ = lambda: [Other]
160+
# >>> class Other: pass
161+
# >>> isinstance(Other, Ancestor) is True
162+
# Do not iterate over cls.__subclasses__() because it returns the entire class tree,
163+
# not just direct children, which leads to O(n^2) lookup.
164+
original_subclasses = getattr(cls, "__dict__", {}).get("__subclasses__", _UNSET)
165+
if original_subclasses is not _UNSET:
166+
for scls in original_subclasses():
167+
if issubclass(subclass, scls):
168+
cls._abc_cache.add(subclass)
169+
return True
145170
# No dice; update negative cache
146171
cls._abc_negative_cache.add(subclass)
147172
return False

Lib/test/test_abc.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -411,7 +411,35 @@ class MyInt(int):
411411
self.assertIsInstance(42, A)
412412
self.assertIsInstance(42, (A,))
413413

414-
def test_issubclass_bad_arguments(self):
414+
def test_subclasses(self):
415+
class A:
416+
pass
417+
418+
class B:
419+
pass
420+
421+
class C:
422+
pass
423+
424+
class Sup(metaclass=abc_ABCMeta):
425+
__subclasses__ = lambda: [A, B]
426+
427+
self.assertIsSubclass(A, Sup)
428+
self.assertIsSubclass(A, (Sup,))
429+
self.assertIsInstance(A(), Sup)
430+
self.assertIsInstance(A(), (Sup,))
431+
432+
self.assertIsSubclass(B, Sup)
433+
self.assertIsSubclass(B, (Sup,))
434+
self.assertIsInstance(B(), Sup)
435+
self.assertIsInstance(B(), (Sup,))
436+
437+
self.assertNotIsSubclass(C, Sup)
438+
self.assertNotIsSubclass(C, (Sup,))
439+
self.assertNotIsInstance(C(), Sup)
440+
self.assertNotIsInstance(C(), (Sup,))
441+
442+
def test_subclasses_bad_arguments(self):
415443
class A(metaclass=abc_ABCMeta):
416444
pass
417445

Lib/test/test_abstract_numbers.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,11 @@ def not_implemented(*args, **kwargs):
2424

2525
class TestNumbers(unittest.TestCase):
2626
def test_int(self):
27-
self.assertTrue(issubclass(int, Integral))
28-
self.assertTrue(issubclass(int, Rational))
29-
self.assertTrue(issubclass(int, Real))
30-
self.assertTrue(issubclass(int, Complex))
31-
self.assertTrue(issubclass(int, Number))
27+
self.assertIsSubclass(int, Integral)
28+
self.assertIsSubclass(int, Rational)
29+
self.assertIsSubclass(int, Real)
30+
self.assertIsSubclass(int, Complex)
31+
self.assertIsSubclass(int, Number)
3232

3333
self.assertEqual(7, int(7).real)
3434
self.assertEqual(0, int(7).imag)
@@ -38,23 +38,23 @@ def test_int(self):
3838
self.assertEqual(1, int(7).denominator)
3939

4040
def test_float(self):
41-
self.assertFalse(issubclass(float, Integral))
42-
self.assertFalse(issubclass(float, Rational))
43-
self.assertTrue(issubclass(float, Real))
44-
self.assertTrue(issubclass(float, Complex))
45-
self.assertTrue(issubclass(float, Number))
41+
self.assertNotIsSubclass(float, Integral)
42+
self.assertNotIsSubclass(float, Rational)
43+
self.assertIsSubclass(float, Real)
44+
self.assertIsSubclass(float, Complex)
45+
self.assertIsSubclass(float, Number)
4646

4747
self.assertEqual(7.3, float(7.3).real)
4848
self.assertEqual(0, float(7.3).imag)
4949
self.assertEqual(7.3, float(7.3).conjugate())
5050
self.assertEqual(-7.3, float(-7.3).conjugate())
5151

5252
def test_complex(self):
53-
self.assertFalse(issubclass(complex, Integral))
54-
self.assertFalse(issubclass(complex, Rational))
55-
self.assertFalse(issubclass(complex, Real))
56-
self.assertTrue(issubclass(complex, Complex))
57-
self.assertTrue(issubclass(complex, Number))
53+
self.assertNotIsSubclass(complex, Integral)
54+
self.assertNotIsSubclass(complex, Rational)
55+
self.assertNotIsSubclass(complex, Real)
56+
self.assertIsSubclass(complex, Complex)
57+
self.assertIsSubclass(complex, Number)
5858

5959
c1, c2 = complex(3, 2), complex(4,1)
6060
# XXX: This is not ideal, but see the comment in math_trunc().

Modules/_abc.c

Lines changed: 87 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -578,6 +578,8 @@ _abc__abc_register_impl(PyObject *module, PyObject *self, PyObject *subclass)
578578
if (result < 0) {
579579
return NULL;
580580
}
581+
582+
/* Actual registration */
581583
_abc_data *impl = _get_impl(module, self);
582584
if (impl == NULL) {
583585
return NULL;
@@ -588,6 +590,49 @@ _abc__abc_register_impl(PyObject *module, PyObject *self, PyObject *subclass)
588590
}
589591
Py_DECREF(impl);
590592

593+
/* Recursively register the subclass in all ABC bases, to avoid recursive lookups.
594+
>>> class Ancestor1(ABC): pass
595+
>>> class Ancestor2(Ancestor1): pass
596+
>>> class Other: pass
597+
>>> Ancestor2.register(Other) # same result for Ancestor1.register(Other)
598+
>>> issubclass(Other, Ancestor2) is True
599+
>>> issubclass(Other, Ancestor1) is True
600+
*/
601+
PyObject *mro = PyObject_GetAttrString(self, "__mro__");
602+
if (mro == NULL) {
603+
return NULL;
604+
}
605+
606+
if (!PyTuple_Check(mro)) {
607+
PyErr_SetString(PyExc_TypeError, "__mro__ is not tuple");
608+
goto error;
609+
}
610+
611+
for (Py_ssize_t pos = 0; pos < PyTuple_GET_SIZE(mro); pos++) {
612+
PyObject *base_class = PyTuple_GET_ITEM(mro, pos); // borrowed
613+
PyObject *base_class_data;
614+
615+
if (PyObject_GetOptionalAttr(base_class, &_Py_ID(_abc_impl),
616+
&base_class_data) < 0) {
617+
goto error;
618+
}
619+
620+
if (PyErr_Occurred()) {
621+
goto error;
622+
}
623+
624+
if (base_class_data == NULL) {
625+
// not ABC class
626+
continue;
627+
}
628+
629+
_abc_data *base_class_state = _abc_data_CAST(base_class_data);
630+
if (_add_to_weak_set(base_class_state, &base_class_state->_abc_registry, subclass) < 0) {
631+
Py_DECREF(base_class_data);
632+
goto error;
633+
}
634+
}
635+
591636
/* Invalidate negative cache */
592637
increment_invalidation_counter(get_abc_state(module));
593638

@@ -602,6 +647,10 @@ _abc__abc_register_impl(PyObject *module, PyObject *self, PyObject *subclass)
602647
}
603648
}
604649
return Py_NewRef(subclass);
650+
651+
error:
652+
Py_XDECREF(mro);
653+
return NULL;
605654
}
606655

607656

@@ -710,6 +759,7 @@ _abc__abc_subclasscheck_impl(PyObject *module, PyObject *self,
710759
PyErr_SetString(PyExc_TypeError, "issubclass() arg 1 must be a class");
711760
return NULL;
712761
}
762+
PyTypeObject *cls = (PyTypeObject *)self;
713763

714764
PyObject *ok, *subclasses = NULL, *result = NULL;
715765
_abcmodule_state *state = NULL;
@@ -800,32 +850,43 @@ _abc__abc_subclasscheck_impl(PyObject *module, PyObject *self,
800850
goto end;
801851
}
802852

803-
/* 6. Check if it's a subclass of a subclass (recursive). */
804-
subclasses = PyObject_CallMethod(self, "__subclasses__", NULL);
805-
if (subclasses == NULL) {
806-
goto end;
807-
}
808-
if (!PyList_Check(subclasses)) {
809-
PyErr_SetString(PyExc_TypeError, "__subclasses__() must return a list");
810-
goto end;
811-
}
812-
for (pos = 0; pos < PyList_GET_SIZE(subclasses); pos++) {
813-
PyObject *scls = PyList_GetItemRef(subclasses, pos);
814-
if (scls == NULL) {
815-
goto end;
816-
}
817-
int r = PyObject_IsSubclass(subclass, scls);
818-
Py_DECREF(scls);
819-
if (r > 0) {
820-
if (_add_to_weak_set(impl, &impl->_abc_cache, subclass) < 0) {
821-
goto end;
822-
}
823-
result = Py_True;
824-
goto end;
825-
}
826-
if (r < 0) {
827-
goto end;
828-
}
853+
/* 6. Check if it's a subclass of a subclass (recursive).
854+
>>> class Ancestor: __subclasses__ = lambda: [Other]
855+
>>> class Other: pass
856+
>>> isinstance(Other, Ancestor) is True
857+
858+
Do not iterate over cls.__subclasses__() because it returns the entire class tree,
859+
not just direct children, which leads to O(n^2) lookup.
860+
*/
861+
PyObject *dict = _PyType_GetDict(cls); // borrowed
862+
PyObject *subclasses_own_method = PyDict_GetItemString(dict, "__subclasses__"); // borrowed
863+
if (subclasses_own_method) {
864+
subclasses = PyObject_CallNoArgs(subclasses_own_method);
865+
if (subclasses == NULL) {
866+
goto end;
867+
}
868+
if (!PyList_Check(subclasses)) {
869+
PyErr_SetString(PyExc_TypeError, "__subclasses__() must return a list");
870+
goto end;
871+
}
872+
for (pos = 0; pos < PyList_GET_SIZE(subclasses); pos++) {
873+
PyObject *scls = PyList_GetItemRef(subclasses, pos);
874+
if (scls == NULL) {
875+
goto end;
876+
}
877+
int r = PyObject_IsSubclass(subclass, scls);
878+
Py_DECREF(scls);
879+
if (r > 0) {
880+
if (_add_to_weak_set(impl, &impl->_abc_cache, subclass) < 0) {
881+
goto end;
882+
}
883+
result = Py_True;
884+
goto end;
885+
}
886+
if (r < 0) {
887+
goto end;
888+
}
889+
}
829890
}
830891

831892
/* No dice; update negative cache. */

0 commit comments

Comments
 (0)