Skip to content
20 changes: 15 additions & 5 deletions mypyc/analysis/dataflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
Cast,
ComparisonOp,
ControlOp,
DecRef,
Extend,
Float,
FloatComparisonOp,
Expand All @@ -25,6 +26,7 @@
GetAttr,
GetElementPtr,
Goto,
IncRef,
InitStatic,
Integer,
IntOp,
Expand Down Expand Up @@ -77,12 +79,11 @@ def __str__(self) -> str:
return f"exits: {exits}\nsucc: {self.succ}\npred: {self.pred}"


def get_cfg(blocks: list[BasicBlock]) -> CFG:
def get_cfg(blocks: list[BasicBlock], *, use_yields: bool = False) -> CFG:
"""Calculate basic block control-flow graph.

The result is a dictionary like this:

basic block index -> (successors blocks, predecesssor blocks)
If use_yields is set, then we treat returns inserted by yields as gotos
instead of exits.
"""
succ_map = {}
pred_map: dict[BasicBlock, list[BasicBlock]] = {}
Expand All @@ -92,7 +93,10 @@ def get_cfg(blocks: list[BasicBlock]) -> CFG:
isinstance(op, ControlOp) for op in block.ops[:-1]
), "Control-flow ops must be at the end of blocks"

succ = list(block.terminator.targets())
if use_yields and isinstance(block.terminator, Return) and block.terminator.yield_target:
succ = [block.terminator.yield_target]
else:
succ = list(block.terminator.targets())
if not succ:
exits.add(block)

Expand Down Expand Up @@ -474,6 +478,12 @@ def visit_assign_multi(self, op: AssignMulti) -> GenAndKill[Value]:
def visit_set_mem(self, op: SetMem) -> GenAndKill[Value]:
return non_trivial_sources(op), set()

def visit_inc_ref(self, op: IncRef) -> GenAndKill[Value]:
return set(), set()

def visit_dec_ref(self, op: DecRef) -> GenAndKill[Value]:
return set(), set()


def analyze_live_regs(blocks: list[BasicBlock], cfg: CFG) -> AnalysisResult[Value]:
"""Calculate live registers at each CFG location.
Expand Down
11 changes: 11 additions & 0 deletions mypyc/codegen/emitmodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
from mypyc.transform.flag_elimination import do_flag_elimination
from mypyc.transform.lower import lower_ir
from mypyc.transform.refcount import insert_ref_count_opcodes
from mypyc.transform.spill import insert_spills
from mypyc.transform.uninit import insert_uninit_checks

# All of the modules being compiled are divided into "groups". A group
Expand Down Expand Up @@ -228,6 +229,12 @@ def compile_scc_to_ir(
if errors.num_errors > 0:
return modules

env_user_functions = {}
for module in modules.values():
for cls in module.classes:
if cls.env_user_function:
env_user_functions[cls.env_user_function] = cls

for module in modules.values():
for fn in module.functions:
# Insert uninit checks.
Expand All @@ -236,6 +243,10 @@ def compile_scc_to_ir(
insert_exception_handling(fn)
# Insert refcount handling.
insert_ref_count_opcodes(fn)

if fn in env_user_functions:
insert_spills(fn, env_user_functions[fn])

# Switch to lower abstraction level IR.
lower_ir(fn, compiler_options)
# Perform optimizations.
Expand Down
7 changes: 7 additions & 0 deletions mypyc/ir/class_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,9 @@ def __init__(
# value of an attribute is the same as the error value.
self.bitmap_attrs: list[str] = []

# If this is a generator environment class, what is the actual method for it
self.env_user_function: FuncIR | None = None

def __repr__(self) -> str:
return (
"ClassIR("
Expand Down Expand Up @@ -394,6 +397,7 @@ def serialize(self) -> JsonDict:
"_always_initialized_attrs": sorted(self._always_initialized_attrs),
"_sometimes_initialized_attrs": sorted(self._sometimes_initialized_attrs),
"init_self_leak": self.init_self_leak,
"env_user_function": self.env_user_function.id if self.env_user_function else None,
}

@classmethod
Expand Down Expand Up @@ -446,6 +450,9 @@ def deserialize(cls, data: JsonDict, ctx: DeserMaps) -> ClassIR:
ir._always_initialized_attrs = set(data["_always_initialized_attrs"])
ir._sometimes_initialized_attrs = set(data["_sometimes_initialized_attrs"])
ir.init_self_leak = data["init_self_leak"]
ir.env_user_function = (
ctx.functions[data["env_user_function"]] if data["env_user_function"] else None
)

return ir

Expand Down
Loading