Skip to content
Draft
6 changes: 5 additions & 1 deletion mypyc/irbuild/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,7 +602,7 @@ def load_type_var(self, name: str, line: int) -> Value:
)
)

def load_literal_value(self, val: int | str | bytes | float | complex | bool) -> Value:
def load_literal_value(self, val: int | str | bytes | float | complex | bool | tuple[Any, ...], dict[Any, Any]) -> Value:
"""Load value of a final name, class-level attribute, or constant folded expression."""
if isinstance(val, bool):
if val:
Expand All @@ -619,6 +619,10 @@ def load_literal_value(self, val: int | str | bytes | float | complex | bool) ->
return self.builder.load_bytes(val)
elif isinstance(val, complex):
return self.builder.load_complex(val)
elif isinstance(val, tuple):
return self.builder.load_tuple(val)
elif isinstance(val, dict):
return self.builder.load_dict(val)
else:
assert False, "Unsupported literal value"

Expand Down
29 changes: 25 additions & 4 deletions mypyc/irbuild/constant_fold.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,28 +10,30 @@

from __future__ import annotations

from typing import Final, Union
from typing import Any, Final, Union

from mypy.constant_fold import constant_fold_binary_op, constant_fold_unary_op
from mypy.nodes import (
BytesExpr,
ComplexExpr,
DictExpr,
Expression,
FloatExpr,
IntExpr,
MemberExpr,
NameExpr,
OpExpr,
StrExpr,
TupleExpr,
UnaryExpr,
Var,
)
from mypyc.irbuild.builder import IRBuilder
from mypyc.irbuild.util import bytes_from_str

# All possible result types of constant folding
ConstantValue = Union[int, float, complex, str, bytes]
CONST_TYPES: Final = (int, float, complex, str, bytes)
ConstantValue = Union[int, float, complex, str, bytes, tuple[Any, ...], dict[Any, Any]]
CONST_TYPES: Final = (int, float, complex, str, bytes, tuple, dict)


def constant_fold_expr(builder: IRBuilder, expr: Expression) -> ConstantValue | None:
Expand Down Expand Up @@ -72,6 +74,25 @@ def constant_fold_expr(builder: IRBuilder, expr: Expression) -> ConstantValue |
value = constant_fold_expr(builder, expr.expr)
if value is not None and not isinstance(value, bytes):
return constant_fold_unary_op(expr.op, value)
elif isinstance(expr, TupleExpr):
folded = tuple(constant_fold_expr(builder, item_expr) for item_expr in expr.items)
if None not in folded:
return folded
elif isinstance(expr, DictExpr):
# NOTE: the builder can't simply use a dict constant like it can with other constants, since dicts are mutable.
# TODO: make the builder load the dict 'constant' by calling copy on a prebuilt constant template instead of building from scratch each time
folded = {
constant_fold_expr(builder, key_expr): constant_fold_expr(builder, value_expr)
for key_expr, value_expr in expr.items
}
if (
len(folded) == len(expr.items)
and None not in folded.keys()
and None not in folded.values()
):
return folded

# TODO use a placeholder instead of None so we can include None in folded tuples/dicts
return None


Expand All @@ -82,7 +103,7 @@ def constant_fold_binary_op_extended(

mypy cannot use constant folded bytes easily so it's simpler to only support them in mypyc.
"""
if not isinstance(left, bytes) and not isinstance(right, bytes):
if not isinstance(left, (bytes, tuple, dict)) and not isinstance(right, (bytes, tuple, dict)):
return constant_fold_binary_op(op, left, right)

if op == "+" and isinstance(left, bytes) and isinstance(right, bytes):
Expand Down
11 changes: 10 additions & 1 deletion mypyc/irbuild/ll_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import sys
from collections.abc import Sequence
from typing import Callable, Final, Optional
from typing import Any, Callable, Final, Optional

from mypy.argmap import map_actuals_to_formals
from mypy.nodes import ARG_POS, ARG_STAR, ARG_STAR2, ArgKind
Expand Down Expand Up @@ -123,6 +123,7 @@
pointer_rprimitive,
short_int_rprimitive,
str_rprimitive,
tuple_rprimitive,
)
from mypyc.irbuild.util import concrete_arg_kind
from mypyc.options import CompilerOptions
Expand Down Expand Up @@ -1354,6 +1355,14 @@ def load_complex(self, value: complex) -> Value:
"""Load a complex literal value."""
return self.add(LoadLiteral(value, object_rprimitive))

def load_tuple(
self, value: tuple[Any, ...]
) -> Value: # should this be RTuple? conditional RTuple when length is known?
return self.add(LoadLiteral(value, tuple_rprimitive))

def load_dict(self, value: dict[Any, Any]) -> Value:
return self.add(LoadLiteral(value, dict_rprimitive))

def load_static_checked(
self,
typ: RType,
Expand Down
Loading