Skip to content
20 changes: 15 additions & 5 deletions mypyc/analysis/dataflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
Cast,
ComparisonOp,
ControlOp,
DecRef,
Extend,
Float,
FloatComparisonOp,
Expand All @@ -25,6 +26,7 @@
GetAttr,
GetElementPtr,
Goto,
IncRef,
InitStatic,
Integer,
IntOp,
Expand Down Expand Up @@ -79,12 +81,11 @@ def __str__(self) -> str:
return "\n".join(lines)


def get_cfg(blocks: list[BasicBlock]) -> CFG:
def get_cfg(blocks: list[BasicBlock], *, use_yields: bool = False) -> CFG:
"""Calculate basic block control-flow graph.

The result is a dictionary like this:

basic block index -> (successors blocks, predecesssor blocks)
If use_yields is set, then we treat returns inserted by yields as gotos
instead of exits.
"""
succ_map = {}
pred_map: dict[BasicBlock, list[BasicBlock]] = {}
Expand All @@ -94,7 +95,10 @@ def get_cfg(blocks: list[BasicBlock]) -> CFG:
isinstance(op, ControlOp) for op in block.ops[:-1]
), "Control-flow ops must be at the end of blocks"

succ = list(block.terminator.targets())
if use_yields and isinstance(block.terminator, Return) and block.terminator.yield_target:
succ = [block.terminator.yield_target]
else:
succ = list(block.terminator.targets())
if not succ:
exits.add(block)

Expand Down Expand Up @@ -494,6 +498,12 @@ def visit_assign_multi(self, op: AssignMulti) -> GenAndKill[Value]:
def visit_set_mem(self, op: SetMem) -> GenAndKill[Value]:
return non_trivial_sources(op), set()

def visit_inc_ref(self, op: IncRef) -> GenAndKill[Value]:
return set(), set()

def visit_dec_ref(self, op: DecRef) -> GenAndKill[Value]:
return set(), set()


def analyze_live_regs(blocks: list[BasicBlock], cfg: CFG) -> AnalysisResult[Value]:
"""Calculate live registers at each CFG location.
Expand Down
5 changes: 5 additions & 0 deletions mypyc/codegen/emitmodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
from mypyc.options import CompilerOptions
from mypyc.transform.exceptions import insert_exception_handling
from mypyc.transform.refcount import insert_ref_count_opcodes
from mypyc.transform.spill import insert_spills
from mypyc.transform.uninit import insert_uninit_checks

# All of the modules being compiled are divided into "groups". A group
Expand Down Expand Up @@ -237,6 +238,10 @@ def compile_scc_to_ir(
for module in modules.values():
for fn in module.functions:
insert_ref_count_opcodes(fn)
for module in modules.values():
for cls in module.classes:
if cls.env_user_function:
insert_spills(cls.env_user_function, cls)

return modules

Expand Down
7 changes: 7 additions & 0 deletions mypyc/ir/class_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,9 @@ def __init__(
# value of an attribute is the same as the error value.
self.bitmap_attrs: list[str] = []

# If this is a generator environment class, what is the actual method for it
self.env_user_function: FuncIR | None = None

def __repr__(self) -> str:
return (
"ClassIR("
Expand Down Expand Up @@ -391,6 +394,7 @@ def serialize(self) -> JsonDict:
"_always_initialized_attrs": sorted(self._always_initialized_attrs),
"_sometimes_initialized_attrs": sorted(self._sometimes_initialized_attrs),
"init_self_leak": self.init_self_leak,
"env_user_function": self.env_user_function.id if self.env_user_function else None,
}

@classmethod
Expand Down Expand Up @@ -442,6 +446,9 @@ def deserialize(cls, data: JsonDict, ctx: DeserMaps) -> ClassIR:
ir._always_initialized_attrs = set(data["_always_initialized_attrs"])
ir._sometimes_initialized_attrs = set(data["_sometimes_initialized_attrs"])
ir.init_self_leak = data["init_self_leak"]
ir.env_user_function = (
ctx.functions[data["env_user_function"]] if data["env_user_function"] else None
)

return ir

Expand Down
Loading