Skip to content

Commit cab2f70

Browse files
committed
Call setup function only when first argument matches outer function
1 parent a20c884 commit cab2f70

File tree

4 files changed

+74
-10
lines changed

4 files changed

+74
-10
lines changed

mypyc/codegen/emitclass.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -678,8 +678,10 @@ def emit_null_check() -> None:
678678
emit_null_check()
679679
return
680680
prefix = emitter.get_group_prefix(new_fn.decl) + NATIVE_PREFIX if native_prefix else PREFIX
681-
new_args = ", ".join([type_arg, new_args])
682-
emitter.emit_line(f"PyObject *self = {prefix}{new_fn.cname(emitter.names)}({new_args});")
681+
all_args = type_arg
682+
if new_args != "":
683+
all_args += ", " + new_args
684+
emitter.emit_line(f"PyObject *self = {prefix}{new_fn.cname(emitter.names)}({all_args});")
683685
emit_null_check()
684686

685687
# skip __init__ if __new__ returns some other type

mypyc/irbuild/expression.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -483,8 +483,15 @@ def translate_super_method_call(builder: IRBuilder, expr: CallExpr, callee: Supe
483483
assert (
484484
len(expr.args) == 1
485485
), f"Expected object.__new__() call to have exactly 1 argument, got {len(expr.args)}"
486-
subtype = builder.accept(expr.args[0])
487-
return builder.add(Call(ir.setup, [subtype], expr.line))
486+
typ_arg = expr.args[0]
487+
method_args = builder.fn_info.fitem.arg_names
488+
if (
489+
isinstance(typ_arg, NameExpr)
490+
and len(method_args) > 0
491+
and method_args[0] == typ_arg.name
492+
):
493+
subtype = builder.accept(expr.args[0])
494+
return builder.add(Call(ir.setup, [subtype], expr.line))
488495

489496
if callee.name == "__new__":
490497
call = "super().__new__()"

mypyc/test-data/irbuild-classes.test

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1619,6 +1619,21 @@ class Test:
16191619
obj.val = val
16201620
return obj
16211621

1622+
def fn() -> Test:
1623+
return Test.__new__(Test, 42)
1624+
1625+
class NewClassMethod:
1626+
val: int
1627+
1628+
@classmethod
1629+
def __new__(cls, val: int) -> NewClassMethod:
1630+
obj = super().__new__(cls)
1631+
obj.val = val
1632+
return obj
1633+
1634+
def fn2() -> NewClassMethod:
1635+
return NewClassMethod.__new__(42)
1636+
16221637
[out]
16231638
def Test.__new__(cls, val):
16241639
cls :: object
@@ -1630,6 +1645,30 @@ L0:
16301645
obj = r0
16311646
obj.val = val; r1 = is_error
16321647
return obj
1648+
def fn():
1649+
r0 :: object
1650+
r1 :: __main__.Test
1651+
L0:
1652+
r0 = __main__.Test :: type
1653+
r1 = Test.__new__(r0, 84)
1654+
return r1
1655+
def NewClassMethod.__new__(cls, val):
1656+
cls :: object
1657+
val :: int
1658+
r0, obj :: __main__.NewClassMethod
1659+
r1 :: bool
1660+
L0:
1661+
r0 = __mypyc__NewClassMethod_setup(cls)
1662+
obj = r0
1663+
obj.val = val; r1 = is_error
1664+
return obj
1665+
def fn2():
1666+
r0 :: object
1667+
r1 :: __main__.NewClassMethod
1668+
L0:
1669+
r0 = __main__.NewClassMethod :: type
1670+
r1 = NewClassMethod.__new__(r0, 84)
1671+
return r1
16331672

16341673
[case testUnsupportedDunderNew]
16351674
from __future__ import annotations

mypyc/test-data/run-classes.test

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3310,6 +3310,10 @@ class RaisesException:
33103310
def __init__(self, val: int) -> None:
33113311
self.val = val
33123312

3313+
class ClsArgNotPassed:
3314+
def __new__(cls) -> Any:
3315+
return super().__new__(str)
3316+
33133317
def test_dunder_new() -> None:
33143318
add_instance: Any = Add(1, 5)
33153319
assert type(add_instance) == Add
@@ -3332,6 +3336,9 @@ def test_dunder_new() -> None:
33323336
not_raised = RaisesException(1)
33333337
assert not_raised.val == 1
33343338

3339+
with assertRaises(TypeError, "object.__new__(str) is not safe, use str.__new__()"):
3340+
str_as_cls = ClsArgNotPassed()
3341+
33353342

33363343
[case testDunderNewInInterpreted]
33373344
from __future__ import annotations
@@ -3368,8 +3375,12 @@ class RaisesException:
33683375
def __init__(self, val: int) -> None:
33693376
self.val = val
33703377

3378+
class ClsArgNotPassed:
3379+
def __new__(cls) -> Any:
3380+
return super().__new__(str)
3381+
33713382
[file driver.py]
3372-
from native import Add, RaisesException
3383+
from native import Add, ClsArgNotPassed, RaisesException
33733384

33743385
from testutil import assertRaises
33753386

@@ -3383,6 +3394,9 @@ with assertRaises(RuntimeError, "Invalid value!"):
33833394
not_raised = RaisesException(1)
33843395
assert not_raised.val == 1
33853396

3397+
with assertRaises(TypeError, "object.__new__(str) is not safe, use str.__new__()"):
3398+
str_as_cls = ClsArgNotPassed()
3399+
33863400
[out]
33873401
running __new__ with 1 and 5
33883402
Add(1, 5)=(1 + 5)
@@ -3394,14 +3408,15 @@ Add(1, 0)=1
33943408
[case testInheritedDunderNew]
33953409
from __future__ import annotations
33963410
from mypy_extensions import mypyc_attr
3411+
from typing_extensions import Self
33973412

33983413
from m import interpreted_subclass
33993414

34003415
@mypyc_attr(allow_interpreted_subclasses=True)
34013416
class Base:
34023417
val: int
34033418

3404-
def __new__(cls, val: int):
3419+
def __new__(cls, val: int) -> Self:
34053420
obj = super().__new__(cls)
34063421
obj.val = val + 1
34073422
return obj
@@ -3410,7 +3425,7 @@ class Base:
34103425
self.init_val = val
34113426

34123427
class Sub(Base):
3413-
def __new__(cls, val: int):
3428+
def __new__(cls, val: int) -> Self:
34143429
return super().__new__(cls, val + 1)
34153430

34163431
def __init__(self, val: int) -> None:
@@ -3425,7 +3440,7 @@ class SubWithoutNew(Base):
34253440
class BaseWithoutInterpretedSubclasses:
34263441
val: int
34273442

3428-
def __new__(cls, val: int):
3443+
def __new__(cls, val: int) -> Self:
34293444
obj = super().__new__(cls)
34303445
obj.val = val + 1
34313446
return obj
@@ -3434,7 +3449,7 @@ class BaseWithoutInterpretedSubclasses:
34343449
self.init_val = val
34353450

34363451
class SubNoInterpreted(BaseWithoutInterpretedSubclasses):
3437-
def __new__(cls, val: int):
3452+
def __new__(cls, val: int) -> Self:
34383453
return super().__new__(cls, val + 1)
34393454

34403455
def __init__(self, val: int) -> None:
@@ -3483,6 +3498,7 @@ def test_interpreted_subclass() -> None:
34833498

34843499
[file m.py]
34853500
from __future__ import annotations
3501+
from typing_extensions import Self
34863502

34873503
def interpreted_subclass(base) -> None:
34883504
b = base(42)
@@ -3491,7 +3507,7 @@ def interpreted_subclass(base) -> None:
34913507
assert b.init_val == 42
34923508

34933509
class InterpretedSub(base):
3494-
def __new__(cls, val: int) -> base:
3510+
def __new__(cls, val: int) -> Self:
34953511
return super().__new__(cls, val + 1)
34963512

34973513
def __init__(self, val: int) -> None:

0 commit comments

Comments
 (0)