Skip to content

Commit 971c56f

Browse files
committed
[mypyc] Transform object.__new__ inside __new__
1 parent dce8e1c commit 971c56f

File tree

4 files changed

+443
-28
lines changed

4 files changed

+443
-28
lines changed

mypyc/irbuild/expression.py

Lines changed: 9 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,6 @@
5757
from mypyc.ir.ops import (
5858
Assign,
5959
BasicBlock,
60-
Call,
6160
ComparisonOp,
6261
Integer,
6362
LoadAddress,
@@ -98,7 +97,11 @@
9897
join_formatted_strings,
9998
tokenizer_printf_style,
10099
)
101-
from mypyc.irbuild.specialize import apply_function_specialization, apply_method_specialization
100+
from mypyc.irbuild.specialize import (
101+
apply_function_specialization,
102+
apply_method_specialization,
103+
translate_object_new,
104+
)
102105
from mypyc.primitives.bytes_ops import bytes_slice_op
103106
from mypyc.primitives.dict_ops import dict_get_item_op, dict_new_op, exact_dict_set_item_op
104107
from mypyc.primitives.generic_ops import iter_op, name_op
@@ -473,35 +476,15 @@ def translate_super_method_call(builder: IRBuilder, expr: CallExpr, callee: Supe
473476
if callee.name in base.method_decls:
474477
break
475478
else:
479+
if callee.name == "__new__":
480+
result = translate_object_new(builder, expr, MemberExpr(callee.call, "__new__"))
481+
if result:
482+
return result
476483
if ir.is_ext_class and ir.builtin_base is None and not ir.inherits_python:
477484
if callee.name == "__init__" and len(expr.args) == 0:
478485
# Call translates to object.__init__(self), which is a
479486
# no-op, so omit the call.
480487
return builder.none()
481-
elif callee.name == "__new__":
482-
# object.__new__(cls)
483-
assert (
484-
len(expr.args) == 1
485-
), f"Expected object.__new__() call to have exactly 1 argument, got {len(expr.args)}"
486-
typ_arg = expr.args[0]
487-
method_args = builder.fn_info.fitem.arg_names
488-
if (
489-
isinstance(typ_arg, NameExpr)
490-
and len(method_args) > 0
491-
and method_args[0] == typ_arg.name
492-
):
493-
subtype = builder.accept(expr.args[0])
494-
return builder.add(Call(ir.setup, [subtype], expr.line))
495-
496-
if callee.name == "__new__":
497-
call = "super().__new__()"
498-
if not ir.is_ext_class:
499-
builder.error(f"{call} not supported for non-extension classes", expr.line)
500-
if ir.inherits_python:
501-
builder.error(
502-
f"{call} not supported for classes inheriting from non-native classes",
503-
expr.line,
504-
)
505488
return translate_call(builder, expr, callee)
506489

507490
decl = base.method_decl(callee.name)

mypyc/irbuild/specialize.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
from mypy.types import AnyType, TypeOfAny
3737
from mypyc.ir.ops import (
3838
BasicBlock,
39+
Call,
3940
Extend,
4041
Integer,
4142
RaiseStandardError,
@@ -68,6 +69,7 @@
6869
is_list_rprimitive,
6970
is_uint8_rprimitive,
7071
list_rprimitive,
72+
object_rprimitive,
7173
set_rprimitive,
7274
str_rprimitive,
7375
uint8_rprimitive,
@@ -1002,3 +1004,38 @@ def translate_ord(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Value
10021004
if isinstance(arg, (StrExpr, BytesExpr)) and len(arg.value) == 1:
10031005
return Integer(ord(arg.value))
10041006
return None
1007+
1008+
1009+
@specialize_function("__new__", object_rprimitive)
1010+
def translate_object_new(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Value | None:
1011+
fn = builder.fn_info
1012+
if fn.name != "__new__":
1013+
return None
1014+
1015+
ir = None
1016+
for cls in builder.classes:
1017+
if cls.name == fn.class_name:
1018+
ir = cls
1019+
if ir is None:
1020+
return None
1021+
1022+
call = "object.__new__()"
1023+
if not ir.is_ext_class:
1024+
builder.error(f"{call} not supported for non-extension classes", expr.line)
1025+
return None
1026+
if ir.inherits_python:
1027+
builder.error(
1028+
f"{call} not supported for classes inheriting from non-native classes", expr.line
1029+
)
1030+
return None
1031+
1032+
assert (
1033+
len(expr.args) == 1
1034+
), f"Expected object.__new__() call to have exactly 1 argument, got {len(expr.args)}"
1035+
typ_arg = expr.args[0]
1036+
method_args = fn.fitem.arg_names
1037+
if isinstance(typ_arg, NameExpr) and len(method_args) > 0 and method_args[0] == typ_arg.name:
1038+
subtype = builder.accept(expr.args[0])
1039+
return builder.add(Call(ir.setup, [subtype], expr.line))
1040+
1041+
return None

mypyc/test-data/irbuild-classes.test

Lines changed: 247 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1662,6 +1662,7 @@ L0:
16621662

16631663
[case testDunderNew]
16641664
from __future__ import annotations
1665+
from typing import Any
16651666

16661667
class Test:
16671668
val: int
@@ -1686,6 +1687,14 @@ class NewClassMethod:
16861687
def fn2() -> NewClassMethod:
16871688
return NewClassMethod.__new__(42)
16881689

1690+
class NotTransformed:
1691+
def __new__(cls, val: int) -> Any:
1692+
return super().__new__(str)
1693+
1694+
@classmethod
1695+
def factory(cls, val: int) -> NotTransformed:
1696+
return super().__new__(cls)
1697+
16891698
[out]
16901699
def Test.__new__(cls, val):
16911700
cls :: object
@@ -1721,6 +1730,233 @@ L0:
17211730
r0 = __main__.NewClassMethod :: type
17221731
r1 = NewClassMethod.__new__(r0, 84)
17231732
return r1
1733+
def NotTransformed.__new__(cls, val):
1734+
cls :: object
1735+
val :: int
1736+
r0 :: object
1737+
r1 :: str
1738+
r2, r3 :: object
1739+
r4 :: object[2]
1740+
r5 :: object_ptr
1741+
r6 :: object
1742+
r7 :: str
1743+
r8, r9 :: object
1744+
r10 :: object[1]
1745+
r11 :: object_ptr
1746+
r12 :: object
1747+
r13 :: str
1748+
L0:
1749+
r0 = builtins :: module
1750+
r1 = 'super'
1751+
r2 = CPyObject_GetAttr(r0, r1)
1752+
r3 = __main__.NotTransformed :: type
1753+
r4 = [r3, cls]
1754+
r5 = load_address r4
1755+
r6 = PyObject_Vectorcall(r2, r5, 2, 0)
1756+
keep_alive r3, cls
1757+
r7 = '__new__'
1758+
r8 = CPyObject_GetAttr(r6, r7)
1759+
r9 = load_address PyUnicode_Type
1760+
r10 = [r9]
1761+
r11 = load_address r10
1762+
r12 = PyObject_Vectorcall(r8, r11, 1, 0)
1763+
keep_alive r9
1764+
r13 = cast(str, r12)
1765+
return r13
1766+
def NotTransformed.factory(cls, val):
1767+
cls :: object
1768+
val :: int
1769+
r0 :: object
1770+
r1 :: str
1771+
r2, r3 :: object
1772+
r4 :: object[2]
1773+
r5 :: object_ptr
1774+
r6 :: object
1775+
r7 :: str
1776+
r8 :: object
1777+
r9 :: object[1]
1778+
r10 :: object_ptr
1779+
r11 :: object
1780+
r12 :: __main__.NotTransformed
1781+
L0:
1782+
r0 = builtins :: module
1783+
r1 = 'super'
1784+
r2 = CPyObject_GetAttr(r0, r1)
1785+
r3 = __main__.NotTransformed :: type
1786+
r4 = [r3, cls]
1787+
r5 = load_address r4
1788+
r6 = PyObject_Vectorcall(r2, r5, 2, 0)
1789+
keep_alive r3, cls
1790+
r7 = '__new__'
1791+
r8 = CPyObject_GetAttr(r6, r7)
1792+
r9 = [cls]
1793+
r10 = load_address r9
1794+
r11 = PyObject_Vectorcall(r8, r10, 1, 0)
1795+
keep_alive cls
1796+
r12 = cast(__main__.NotTransformed, r11)
1797+
return r12
1798+
1799+
[case testObjectDunderNew]
1800+
from __future__ import annotations
1801+
from typing import Any
1802+
1803+
class Test:
1804+
val: int
1805+
1806+
def __new__(cls, val: int) -> Test:
1807+
obj = object.__new__(cls)
1808+
obj.val = val
1809+
return obj
1810+
1811+
def fn() -> Test:
1812+
return Test.__new__(Test, 42)
1813+
1814+
class NewClassMethod:
1815+
val: int
1816+
1817+
@classmethod
1818+
def __new__(cls, val: int) -> NewClassMethod:
1819+
obj = object.__new__(cls)
1820+
obj.val = val
1821+
return obj
1822+
1823+
def fn2() -> NewClassMethod:
1824+
return NewClassMethod.__new__(42)
1825+
1826+
class NotTransformed:
1827+
def __new__(cls, val: int) -> Any:
1828+
return object.__new__(str)
1829+
1830+
@classmethod
1831+
def factory(cls, val: int) -> NotTransformed:
1832+
return object.__new__(cls)
1833+
1834+
class ObjectNewOutsideDunderNew:
1835+
def __init__(self) -> None:
1836+
object.__new__(ObjectNewOutsideDunderNew)
1837+
1838+
def object_new_outside_class() -> None:
1839+
object.__new__(Test)
1840+
1841+
[out]
1842+
def Test.__new__(cls, val):
1843+
cls :: object
1844+
val :: int
1845+
r0, obj :: __main__.Test
1846+
r1 :: bool
1847+
L0:
1848+
r0 = __mypyc__Test_setup(cls)
1849+
obj = r0
1850+
obj.val = val; r1 = is_error
1851+
return obj
1852+
def fn():
1853+
r0 :: object
1854+
r1 :: __main__.Test
1855+
L0:
1856+
r0 = __main__.Test :: type
1857+
r1 = Test.__new__(r0, 84)
1858+
return r1
1859+
def NewClassMethod.__new__(cls, val):
1860+
cls :: object
1861+
val :: int
1862+
r0, obj :: __main__.NewClassMethod
1863+
r1 :: bool
1864+
L0:
1865+
r0 = __mypyc__NewClassMethod_setup(cls)
1866+
obj = r0
1867+
obj.val = val; r1 = is_error
1868+
return obj
1869+
def fn2():
1870+
r0 :: object
1871+
r1 :: __main__.NewClassMethod
1872+
L0:
1873+
r0 = __main__.NewClassMethod :: type
1874+
r1 = NewClassMethod.__new__(r0, 84)
1875+
return r1
1876+
def NotTransformed.__new__(cls, val):
1877+
cls :: object
1878+
val :: int
1879+
r0 :: object
1880+
r1 :: str
1881+
r2, r3 :: object
1882+
r4 :: str
1883+
r5 :: object[2]
1884+
r6 :: object_ptr
1885+
r7 :: object
1886+
r8 :: str
1887+
L0:
1888+
r0 = builtins :: module
1889+
r1 = 'object'
1890+
r2 = CPyObject_GetAttr(r0, r1)
1891+
r3 = load_address PyUnicode_Type
1892+
r4 = '__new__'
1893+
r5 = [r2, r3]
1894+
r6 = load_address r5
1895+
r7 = PyObject_VectorcallMethod(r4, r6, 9223372036854775810, 0)
1896+
keep_alive r2, r3
1897+
r8 = cast(str, r7)
1898+
return r8
1899+
def NotTransformed.factory(cls, val):
1900+
cls :: object
1901+
val :: int
1902+
r0 :: object
1903+
r1 :: str
1904+
r2 :: object
1905+
r3 :: str
1906+
r4 :: object[2]
1907+
r5 :: object_ptr
1908+
r6 :: object
1909+
r7 :: __main__.NotTransformed
1910+
L0:
1911+
r0 = builtins :: module
1912+
r1 = 'object'
1913+
r2 = CPyObject_GetAttr(r0, r1)
1914+
r3 = '__new__'
1915+
r4 = [r2, cls]
1916+
r5 = load_address r4
1917+
r6 = PyObject_VectorcallMethod(r3, r5, 9223372036854775810, 0)
1918+
keep_alive r2, cls
1919+
r7 = cast(__main__.NotTransformed, r6)
1920+
return r7
1921+
def ObjectNewOutsideDunderNew.__init__(self):
1922+
self :: __main__.ObjectNewOutsideDunderNew
1923+
r0 :: object
1924+
r1 :: str
1925+
r2, r3 :: object
1926+
r4 :: str
1927+
r5 :: object[2]
1928+
r6 :: object_ptr
1929+
r7 :: object
1930+
L0:
1931+
r0 = builtins :: module
1932+
r1 = 'object'
1933+
r2 = CPyObject_GetAttr(r0, r1)
1934+
r3 = __main__.ObjectNewOutsideDunderNew :: type
1935+
r4 = '__new__'
1936+
r5 = [r2, r3]
1937+
r6 = load_address r5
1938+
r7 = PyObject_VectorcallMethod(r4, r6, 9223372036854775810, 0)
1939+
keep_alive r2, r3
1940+
return 1
1941+
def object_new_outside_class():
1942+
r0 :: object
1943+
r1 :: str
1944+
r2, r3 :: object
1945+
r4 :: str
1946+
r5 :: object[2]
1947+
r6 :: object_ptr
1948+
r7 :: object
1949+
L0:
1950+
r0 = builtins :: module
1951+
r1 = 'object'
1952+
r2 = CPyObject_GetAttr(r0, r1)
1953+
r3 = __main__.Test :: type
1954+
r4 = '__new__'
1955+
r5 = [r2, r3]
1956+
r6 = load_address r5
1957+
r7 = PyObject_VectorcallMethod(r4, r6, 9223372036854775810, 0)
1958+
keep_alive r2, r3
1959+
return 1
17241960

17251961
[case testUnsupportedDunderNew]
17261962
from __future__ import annotations
@@ -1729,11 +1965,20 @@ from mypy_extensions import mypyc_attr
17291965
@mypyc_attr(native_class=False)
17301966
class NonNative:
17311967
def __new__(cls) -> NonNative:
1732-
return super().__new__(cls) # E: super().__new__() not supported for non-extension classes
1968+
return super().__new__(cls) # E: object.__new__() not supported for non-extension classes
17331969

17341970
class InheritsPython(dict):
17351971
def __new__(cls) -> InheritsPython:
1736-
return super().__new__(cls) # E: super().__new__() not supported for classes inheriting from non-native classes
1972+
return super().__new__(cls) # E: object.__new__() not supported for classes inheriting from non-native classes
1973+
1974+
@mypyc_attr(native_class=False)
1975+
class NonNativeObjectNew:
1976+
def __new__(cls) -> NonNativeObjectNew:
1977+
return object.__new__(cls) # E: object.__new__() not supported for non-extension classes
1978+
1979+
class InheritsPythonObjectNew(dict):
1980+
def __new__(cls) -> InheritsPythonObjectNew:
1981+
return object.__new__(cls) # E: object.__new__() not supported for classes inheriting from non-native classes
17371982

17381983
[case testClassWithFreeList]
17391984
from mypy_extensions import mypyc_attr, trait

0 commit comments

Comments
 (0)