Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions mypyc/irbuild/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -1437,6 +1437,10 @@ def add_function(self, func_ir: FuncIR, line: int) -> None:
self.function_names.add(name)
self.functions.append(func_ir)

def get_current_class_ir(self) -> ClassIR | None:
type_info = self.fn_info.fitem.info
return self.mapper.type_to_ir.get(type_info)


def gen_arg_defaults(builder: IRBuilder) -> None:
"""Generate blocks for arguments that have default values.
Expand Down
35 changes: 9 additions & 26 deletions mypyc/irbuild/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@
from mypyc.ir.ops import (
Assign,
BasicBlock,
Call,
ComparisonOp,
Integer,
LoadAddress,
Expand Down Expand Up @@ -98,7 +97,11 @@
join_formatted_strings,
tokenizer_printf_style,
)
from mypyc.irbuild.specialize import apply_function_specialization, apply_method_specialization
from mypyc.irbuild.specialize import (
apply_function_specialization,
apply_method_specialization,
translate_object_new,
)
from mypyc.primitives.bytes_ops import bytes_slice_op
from mypyc.primitives.dict_ops import dict_get_item_op, dict_new_op, exact_dict_set_item_op
from mypyc.primitives.generic_ops import iter_op, name_op
Expand Down Expand Up @@ -473,35 +476,15 @@ def translate_super_method_call(builder: IRBuilder, expr: CallExpr, callee: Supe
if callee.name in base.method_decls:
break
else:
if callee.name == "__new__":
result = translate_object_new(builder, expr, MemberExpr(callee.call, "__new__"))
if result:
return result
if ir.is_ext_class and ir.builtin_base is None and not ir.inherits_python:
if callee.name == "__init__" and len(expr.args) == 0:
# Call translates to object.__init__(self), which is a
# no-op, so omit the call.
return builder.none()
elif callee.name == "__new__":
# object.__new__(cls)
assert (
len(expr.args) == 1
), f"Expected object.__new__() call to have exactly 1 argument, got {len(expr.args)}"
typ_arg = expr.args[0]
method_args = builder.fn_info.fitem.arg_names
if (
isinstance(typ_arg, NameExpr)
and len(method_args) > 0
and method_args[0] == typ_arg.name
):
subtype = builder.accept(expr.args[0])
return builder.add(Call(ir.setup, [subtype], expr.line))

if callee.name == "__new__":
call = "super().__new__()"
if not ir.is_ext_class:
builder.error(f"{call} not supported for non-extension classes", expr.line)
if ir.inherits_python:
builder.error(
f"{call} not supported for classes inheriting from non-native classes",
expr.line,
)
return translate_call(builder, expr, callee)

decl = base.method_decl(callee.name)
Expand Down
34 changes: 34 additions & 0 deletions mypyc/irbuild/specialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from mypy.types import AnyType, TypeOfAny
from mypyc.ir.ops import (
BasicBlock,
Call,
Extend,
Integer,
RaiseStandardError,
Expand Down Expand Up @@ -68,6 +69,7 @@
is_list_rprimitive,
is_uint8_rprimitive,
list_rprimitive,
object_rprimitive,
set_rprimitive,
str_rprimitive,
uint8_rprimitive,
Expand Down Expand Up @@ -1002,3 +1004,35 @@ def translate_ord(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Value
if isinstance(arg, (StrExpr, BytesExpr)) and len(arg.value) == 1:
return Integer(ord(arg.value))
return None


@specialize_function("__new__", object_rprimitive)
def translate_object_new(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Value | None:
fn = builder.fn_info
if fn.name != "__new__":
return None

ir = builder.get_current_class_ir()
if ir is None:
return None

call = "object.__new__()"
if not ir.is_ext_class:
builder.error(f"{call} not supported for non-extension classes", expr.line)
return None
if ir.inherits_python:
builder.error(
f"{call} not supported for classes inheriting from non-native classes", expr.line
)
return None
if len(expr.args) != 1:
builder.error(f"{call} supported only with 1 argument, got {len(expr.args)}", expr.line)
return None

typ_arg = expr.args[0]
method_args = fn.fitem.arg_names
if isinstance(typ_arg, NameExpr) and len(method_args) > 0 and method_args[0] == typ_arg.name:
subtype = builder.accept(expr.args[0])
return builder.add(Call(ir.setup, [subtype], expr.line))

return None
29 changes: 29 additions & 0 deletions mypyc/irbuild/statement.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
ListExpr,
Lvalue,
MatchStmt,
NameExpr,
OperatorAssignmentStmt,
RaiseStmt,
ReturnStmt,
Expand Down Expand Up @@ -170,10 +171,38 @@ def transform_return_stmt(builder: IRBuilder, stmt: ReturnStmt) -> None:
builder.nonlocal_control[-1].gen_return(builder, retval, stmt.line)


def check_unsupported_cls_assignment(builder: IRBuilder, stmt: AssignmentStmt) -> None:
fn = builder.fn_info
method_args = fn.fitem.arg_names
if fn.name != "__new__" or len(method_args) == 0:
return

ir = builder.get_current_class_ir()
if ir is None or ir.inherits_python or not ir.is_ext_class:
return

cls_arg = method_args[0]
lvalues: list[Expression] = []
for lvalue in stmt.lvalues:
if isinstance(lvalue, (TupleExpr, ListExpr)):
lvalues += lvalue.items
else:
lvalues.append(lvalue)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There could be a nested expression, e.g. [a, [b, c]] = ..., so this needs to be recursive.


for lvalue in lvalues:
if isinstance(lvalue, NameExpr) and lvalue.name == cls_arg:
# Disallowed because it could break the transformation of object.__new__ calls
# inside __new__ methods.
builder.error(
f"Assignment to argument {cls_arg} in __new__ method unsupported", stmt.line
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use double quotes around Python names for consistency, i.e. '... "{cls_arg}" in "__new__" method ...'.

)


def transform_assignment_stmt(builder: IRBuilder, stmt: AssignmentStmt) -> None:
lvalues = stmt.lvalues
assert lvalues
builder.disallow_class_assignments(lvalues, stmt.line)
check_unsupported_cls_assignment(builder, stmt)
first_lvalue = lvalues[0]
if stmt.type and isinstance(stmt.rvalue, TempNode):
# This is actually a variable annotation without initializer. Don't generate
Expand Down
Loading
Loading