Skip to content

Commit ef41556

Browse files
authored
Fix order of attribute look ups for Super (#1686)
1 parent 000f784 commit ef41556

File tree

3 files changed

+71
-4
lines changed

3 files changed

+71
-4
lines changed

astroid/objects.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -135,10 +135,11 @@ def name(self):
135135
def qname(self):
136136
return "super"
137137

138-
def igetattr(self, name, context=None):
138+
def igetattr(self, name: str, context: InferenceContext | None = None):
139139
"""Retrieve the inferred values of the given attribute name."""
140-
141-
if name in self.special_attributes:
140+
# '__class__' is a special attribute that should be taken directly
141+
# from the special attributes dict
142+
if name == "__class__":
142143
yield self.special_attributes.lookup(name)
143144
return
144145

@@ -205,6 +206,12 @@ def igetattr(self, name, context=None):
205206
else:
206207
yield bases.BoundMethod(inferred, cls)
207208

209+
# Only if we haven't found any explicit overwrites for the
210+
# attribute we look it up in the special attributes
211+
if not found and name in self.special_attributes:
212+
yield self.special_attributes.lookup(name)
213+
return
214+
208215
if not found:
209216
raise AttributeInferenceError(target=self, attribute=name, context=context)
210217

tests/unittest_inference.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3808,6 +3808,50 @@ class A: pass
38083808
with pytest.raises(InferenceError):
38093809
next(ast_node5.infer())
38103810

3811+
ast_nodes6 = extract_node(
3812+
"""
3813+
class A: pass
3814+
class B(A): pass
3815+
class C: pass
3816+
A.__new__(A) #@
3817+
A.__new__(B) #@
3818+
B.__new__(A) #@
3819+
B.__new__(B) #@
3820+
C.__new__(A) #@
3821+
"""
3822+
)
3823+
instance_A1 = next(ast_nodes6[0].infer())
3824+
assert instance_A1._proxied.name == "A"
3825+
instance_B1 = next(ast_nodes6[1].infer())
3826+
assert instance_B1._proxied.name == "B"
3827+
instance_A2 = next(ast_nodes6[2].infer())
3828+
assert instance_A2._proxied.name == "A"
3829+
instance_B2 = next(ast_nodes6[3].infer())
3830+
assert instance_B2._proxied.name == "B"
3831+
instance_A3 = next(ast_nodes6[4].infer())
3832+
assert instance_A3._proxied.name == "A"
3833+
3834+
ast_nodes7 = extract_node(
3835+
"""
3836+
import enum
3837+
class A(enum.EnumMeta): pass
3838+
class B(enum.EnumMeta):
3839+
def __new__(mcs, value, **kwargs):
3840+
return super().__new__(mcs, "str", (enum.Enum,), enum._EnumDict(), **kwargs)
3841+
class C(enum.EnumMeta):
3842+
def __new__(mcs, **kwargs):
3843+
return super().__new__(A, "str", (enum.Enum,), enum._EnumDict(), **kwargs)
3844+
B("") #@
3845+
C() #@
3846+
"""
3847+
)
3848+
instance_B = next(ast_nodes7[0].infer())
3849+
assert instance_B._proxied.name == "B"
3850+
instance_C = next(ast_nodes7[1].infer())
3851+
# TODO: This should be A. However, we don't infer EnumMeta.__new__
3852+
# correctly.
3853+
assert instance_C._proxied.name == "C"
3854+
38113855
@pytest.mark.xfail(reason="Does not support function metaclasses")
38123856
def test_function_metaclasses(self):
38133857
# These are not supported right now, although

tests/unittest_objects.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
import unittest
88

9-
from astroid import bases, builder, nodes, objects
9+
from astroid import bases, builder, nodes, objects, util
1010
from astroid.exceptions import AttributeInferenceError, InferenceError, SuperError
1111
from astroid.objects import Super
1212

@@ -552,6 +552,22 @@ def foo(self): return super()
552552
super_obj = next(builder.extract_node(code).infer())
553553
self.assertEqual(super_obj.qname(), "super")
554554

555+
def test_super_new_call(self) -> None:
556+
"""Test that __new__ returns an object or node and not a (Un)BoundMethod."""
557+
new_call_result: nodes.Name = builder.extract_node(
558+
"""
559+
import enum
560+
class ChoicesMeta(enum.EnumMeta):
561+
def __new__(metacls, classname, bases, classdict, **kwds):
562+
cls = super().__new__(metacls, "str", (enum.Enum,), enum._EnumDict(), **kwargs)
563+
cls #@
564+
"""
565+
)
566+
inferred = list(new_call_result.infer())
567+
assert all(
568+
isinstance(i, (nodes.NodeNG, type(util.Uninferable))) for i in inferred
569+
)
570+
555571

556572
if __name__ == "__main__":
557573
unittest.main()

0 commit comments

Comments
 (0)