Skip to content

Commit 8eac14e

Browse files
committed
Address review comments
1 parent a05ea96 commit 8eac14e

File tree

4 files changed

+154
-52
lines changed

4 files changed

+154
-52
lines changed

mypyc/irbuild/builder.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1437,6 +1437,10 @@ def add_function(self, func_ir: FuncIR, line: int) -> None:
14371437
self.function_names.add(name)
14381438
self.functions.append(func_ir)
14391439

1440+
def get_current_class_ir(self) -> ClassIR | None:
1441+
type_info = self.fn_info.fitem.info
1442+
return self.mapper.type_to_ir.get(type_info)
1443+
14401444

14411445
def gen_arg_defaults(builder: IRBuilder) -> None:
14421446
"""Generate blocks for arguments that have default values.

mypyc/irbuild/specialize.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1012,10 +1012,7 @@ def translate_object_new(builder: IRBuilder, expr: CallExpr, callee: RefExpr) ->
10121012
if fn.name != "__new__":
10131013
return None
10141014

1015-
ir = None
1016-
for cls in builder.classes:
1017-
if cls.name == fn.class_name:
1018-
ir = cls
1015+
ir = builder.get_current_class_ir()
10191016
if ir is None:
10201017
return None
10211018

mypyc/irbuild/statement.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
ListExpr,
3434
Lvalue,
3535
MatchStmt,
36+
NameExpr,
3637
OperatorAssignmentStmt,
3738
RaiseStmt,
3839
ReturnStmt,
@@ -170,10 +171,38 @@ def transform_return_stmt(builder: IRBuilder, stmt: ReturnStmt) -> None:
170171
builder.nonlocal_control[-1].gen_return(builder, retval, stmt.line)
171172

172173

174+
def check_unsupported_cls_assignment(builder: IRBuilder, stmt: AssignmentStmt) -> None:
175+
fn = builder.fn_info
176+
method_args = fn.fitem.arg_names
177+
if fn.name != "__new__" or len(method_args) == 0:
178+
return
179+
180+
ir = builder.get_current_class_ir()
181+
if ir is None or ir.inherits_python or not ir.is_ext_class:
182+
return
183+
184+
cls_arg = method_args[0]
185+
lvalues: list[Expression] = []
186+
for lvalue in stmt.lvalues:
187+
if isinstance(lvalue, (TupleExpr, ListExpr)):
188+
lvalues += lvalue.items
189+
else:
190+
lvalues.append(lvalue)
191+
192+
for lvalue in lvalues:
193+
if isinstance(lvalue, NameExpr) and lvalue.name == cls_arg:
194+
# Disallowed because it could break the transformation of object.__new__ calls
195+
# inside __new__ methods.
196+
builder.error(
197+
f"Assignment to argument {cls_arg} in __new__ method unsupported", stmt.line
198+
)
199+
200+
173201
def transform_assignment_stmt(builder: IRBuilder, stmt: AssignmentStmt) -> None:
174202
lvalues = stmt.lvalues
175203
assert lvalues
176204
builder.disallow_class_assignments(lvalues, stmt.line)
205+
check_unsupported_cls_assignment(builder, stmt)
177206
first_lvalue = lvalues[0]
178207
if stmt.type and isinstance(stmt.rvalue, TempNode):
179208
# This is actually a variable annotation without initializer. Don't generate

mypyc/test-data/irbuild-classes.test

Lines changed: 120 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1691,8 +1691,8 @@ class NotTransformed:
16911691
def __new__(cls, val: int) -> Any:
16921692
return super().__new__(str)
16931693

1694-
@classmethod
1695-
def factory(cls, val: int) -> NotTransformed:
1694+
def factory(cls: Any, val: int) -> Any:
1695+
cls = str
16961696
return super().__new__(cls)
16971697

16981698
[out]
@@ -1766,38 +1766,39 @@ L0:
17661766
def NotTransformed.factory(cls, val):
17671767
cls :: object
17681768
val :: int
1769-
r0 :: object
1770-
r1 :: str
1771-
r2, r3 :: object
1772-
r4 :: object[2]
1773-
r5 :: object_ptr
1774-
r6 :: object
1775-
r7 :: str
1776-
r8 :: object
1777-
r9 :: object[1]
1778-
r10 :: object_ptr
1779-
r11 :: object
1780-
r12 :: __main__.NotTransformed
1769+
r0, r1 :: object
1770+
r2 :: str
1771+
r3, r4 :: object
1772+
r5 :: object[2]
1773+
r6 :: object_ptr
1774+
r7 :: object
1775+
r8 :: str
1776+
r9 :: object
1777+
r10 :: object[1]
1778+
r11 :: object_ptr
1779+
r12 :: object
17811780
L0:
1782-
r0 = builtins :: module
1783-
r1 = 'super'
1784-
r2 = CPyObject_GetAttr(r0, r1)
1785-
r3 = __main__.NotTransformed :: type
1786-
r4 = [r3, cls]
1787-
r5 = load_address r4
1788-
r6 = PyObject_Vectorcall(r2, r5, 2, 0)
1789-
keep_alive r3, cls
1790-
r7 = '__new__'
1791-
r8 = CPyObject_GetAttr(r6, r7)
1792-
r9 = [cls]
1793-
r10 = load_address r9
1794-
r11 = PyObject_Vectorcall(r8, r10, 1, 0)
1781+
r0 = load_address PyUnicode_Type
1782+
cls = r0
1783+
r1 = builtins :: module
1784+
r2 = 'super'
1785+
r3 = CPyObject_GetAttr(r1, r2)
1786+
r4 = __main__.NotTransformed :: type
1787+
r5 = [r4, cls]
1788+
r6 = load_address r5
1789+
r7 = PyObject_Vectorcall(r3, r6, 2, 0)
1790+
keep_alive r4, cls
1791+
r8 = '__new__'
1792+
r9 = CPyObject_GetAttr(r7, r8)
1793+
r10 = [cls]
1794+
r11 = load_address r10
1795+
r12 = PyObject_Vectorcall(r9, r11, 1, 0)
17951796
keep_alive cls
1796-
r12 = cast(__main__.NotTransformed, r11)
17971797
return r12
17981798

1799-
[case testObjectDunderNew]
1799+
[case testObjectDunderNew_64bit]
18001800
from __future__ import annotations
1801+
from mypy_extensions import mypyc_attr
18011802
from typing import Any
18021803

18031804
class Test:
@@ -1827,10 +1828,21 @@ class NotTransformed:
18271828
def __new__(cls, val: int) -> Any:
18281829
return object.__new__(str)
18291830

1830-
@classmethod
1831-
def factory(cls, val: int) -> NotTransformed:
1831+
def factory(cls: Any, val: int) -> Any:
1832+
cls = str
18321833
return object.__new__(cls)
18331834

1835+
@mypyc_attr(native_class=False)
1836+
class NonNative:
1837+
def __new__(cls: Any) -> Any:
1838+
cls = str
1839+
return cls("str")
1840+
1841+
class InheritsPython(dict):
1842+
def __new__(cls: Any) -> Any:
1843+
cls = dict
1844+
return cls({})
1845+
18341846
class ObjectNewOutsideDunderNew:
18351847
def __init__(self) -> None:
18361848
object.__new__(ObjectNewOutsideDunderNew)
@@ -1899,25 +1911,69 @@ L0:
18991911
def NotTransformed.factory(cls, val):
19001912
cls :: object
19011913
val :: int
1902-
r0 :: object
1903-
r1 :: str
1904-
r2 :: object
1905-
r3 :: str
1906-
r4 :: object[2]
1907-
r5 :: object_ptr
1908-
r6 :: object
1909-
r7 :: __main__.NotTransformed
1914+
r0, r1 :: object
1915+
r2 :: str
1916+
r3 :: object
1917+
r4 :: str
1918+
r5 :: object[2]
1919+
r6 :: object_ptr
1920+
r7 :: object
19101921
L0:
1911-
r0 = builtins :: module
1912-
r1 = 'object'
1913-
r2 = CPyObject_GetAttr(r0, r1)
1914-
r3 = '__new__'
1915-
r4 = [r2, cls]
1916-
r5 = load_address r4
1917-
r6 = PyObject_VectorcallMethod(r3, r5, 9223372036854775810, 0)
1918-
keep_alive r2, cls
1919-
r7 = cast(__main__.NotTransformed, r6)
1922+
r0 = load_address PyUnicode_Type
1923+
cls = r0
1924+
r1 = builtins :: module
1925+
r2 = 'object'
1926+
r3 = CPyObject_GetAttr(r1, r2)
1927+
r4 = '__new__'
1928+
r5 = [r3, cls]
1929+
r6 = load_address r5
1930+
r7 = PyObject_VectorcallMethod(r4, r6, 9223372036854775810, 0)
1931+
keep_alive r3, cls
19201932
return r7
1933+
def __new___NonNative_obj.__get__(__mypyc_self__, instance, owner):
1934+
__mypyc_self__, instance, owner, r0 :: object
1935+
r1 :: bit
1936+
r2 :: object
1937+
L0:
1938+
r0 = load_address _Py_NoneStruct
1939+
r1 = instance == r0
1940+
if r1 goto L1 else goto L2 :: bool
1941+
L1:
1942+
return __mypyc_self__
1943+
L2:
1944+
r2 = PyMethod_New(__mypyc_self__, instance)
1945+
return r2
1946+
def __new___NonNative_obj.__call__(__mypyc_self__, cls):
1947+
__mypyc_self__ :: __main__.__new___NonNative_obj
1948+
cls, r0 :: object
1949+
r1 :: str
1950+
r2 :: object[1]
1951+
r3 :: object_ptr
1952+
r4 :: object
1953+
L0:
1954+
r0 = load_address PyUnicode_Type
1955+
cls = r0
1956+
r1 = 'str'
1957+
r2 = [r1]
1958+
r3 = load_address r2
1959+
r4 = PyObject_Vectorcall(cls, r3, 1, 0)
1960+
keep_alive r1
1961+
return r4
1962+
def InheritsPython.__new__(cls):
1963+
cls, r0 :: object
1964+
r1 :: dict
1965+
r2 :: object[1]
1966+
r3 :: object_ptr
1967+
r4 :: object
1968+
L0:
1969+
r0 = load_address PyDict_Type
1970+
cls = r0
1971+
r1 = PyDict_New()
1972+
r2 = [r1]
1973+
r3 = load_address r2
1974+
r4 = PyObject_Vectorcall(cls, r3, 1, 0)
1975+
keep_alive r1
1976+
return r4
19211977
def ObjectNewOutsideDunderNew.__init__(self):
19221978
self :: __main__.ObjectNewOutsideDunderNew
19231979
r0 :: object
@@ -1961,6 +2017,7 @@ L0:
19612017
[case testUnsupportedDunderNew]
19622018
from __future__ import annotations
19632019
from mypy_extensions import mypyc_attr
2020+
from typing import Any
19642021

19652022
@mypyc_attr(native_class=False)
19662023
class NonNative:
@@ -1980,6 +2037,21 @@ class InheritsPythonObjectNew(dict):
19802037
def __new__(cls) -> InheritsPythonObjectNew:
19812038
return object.__new__(cls) # E: object.__new__() not supported for classes inheriting from non-native classes
19822039

2040+
class ClsAssignment:
2041+
def __new__(cls: Any) -> Any:
2042+
cls = str # E: Assignment to argument cls in __new__ method unsupported
2043+
return super().__new__(cls)
2044+
2045+
class ClsTupleAssignment:
2046+
def __new__(class_i_want: Any, val: int) -> Any:
2047+
class_i_want, val = dict, 1 # E: Assignment to argument class_i_want in __new__ method unsupported
2048+
return object.__new__(class_i_want)
2049+
2050+
class ClsListAssignment:
2051+
def __new__(cls: Any, val: str) -> Any:
2052+
[cls, val] = [object, "object"] # E: Assignment to argument cls in __new__ method unsupported
2053+
return object.__new__(cls)
2054+
19832055
[case testClassWithFreeList]
19842056
from mypy_extensions import mypyc_attr, trait
19852057

0 commit comments

Comments
 (0)