Skip to content

Commit 8352721

Browse files
authored
rework cf/scf lowering + rework compactify (#200)
This PR fixes a few bugs due to loop lowering. - the `lowering.Frame` is refactored to allow lowering transform fixes non-terminated blocks by pointing it to the `next_block` (e.g `lowering.cf` and `lowering.func`) - move the lowering transform for `func` dialect into separate `lowering.func` and `lowering.call` now one can optionally choose if the invoke statements is supported (partially address #155) - refactor the compactify rewrites by splitting the original rule into a few new rules with better readbility and more correctness.
1 parent 6f6d7d4 commit 8352721

File tree

35 files changed

+983
-706
lines changed

35 files changed

+983
-706
lines changed

src/kirin/dialects/cf/emit.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,17 +30,20 @@ def emit_cbr(
3030
cond = frame.get(stmt.cond)
3131
interp.writeln(frame, f"if {cond}")
3232
frame.indent += 1
33-
ori = frame.get_values(stmt.then_successor.args)
3433
values = frame.get_values(stmt.then_arguments)
35-
for x, y in zip(ori, values):
34+
block_values = tuple(interp.ssa_id[x] for x in stmt.then_successor.args)
35+
frame.set_values(stmt.then_successor.args, block_values)
36+
for x, y in zip(block_values, values):
3637
interp.writeln(frame, f"{x} = {y};")
3738
interp.writeln(frame, f"@goto {interp.block_id[stmt.then_successor]};")
3839
frame.indent -= 1
3940
interp.writeln(frame, "else")
4041
frame.indent += 1
41-
ori = frame.get_values(stmt.else_successor.args)
42+
4243
values = frame.get_values(stmt.else_arguments)
43-
for x, y in zip(ori, values):
44+
block_values = tuple(interp.ssa_id[x] for x in stmt.else_successor.args)
45+
frame.set_values(stmt.else_successor.args, block_values)
46+
for x, y in zip(block_values, values):
4447
interp.writeln(frame, f"{x} = {y};")
4548
interp.writeln(frame, f"@goto {interp.block_id[stmt.else_successor]};")
4649
frame.indent -= 1

src/kirin/dialects/func/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33

44
from kirin.dialects.func import (
55
emit as emit,
6-
lower as lower,
76
interp as interp,
87
constprop as constprop,
98
typeinfer as typeinfer,

src/kirin/dialects/func/lower.py

Lines changed: 0 additions & 185 deletions
This file was deleted.

src/kirin/dialects/lowering/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,4 @@
44
provide different lowering strategies for existing statements.
55
"""
66

7-
from . import cf as cf
7+
from . import cf as cf, call as call, func as func
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import ast
2+
3+
from kirin import ir, lowering
4+
from kirin.dialects import func
5+
from kirin.exceptions import DialectLoweringError
6+
7+
dialect = ir.Dialect("lowering.call")
8+
9+
10+
@dialect.register
11+
class Lowering(lowering.FromPythonAST):
12+
13+
def lower_Call_local(
14+
self, state: lowering.LoweringState, callee: ir.SSAValue, node: ast.Call
15+
) -> lowering.Result:
16+
args, keywords = self.__lower_Call_args_kwargs(state, node)
17+
stmt = func.Call(callee, args, kwargs=keywords)
18+
return lowering.Result(state.append_stmt(stmt))
19+
20+
def lower_Call_global_method(
21+
self,
22+
state: lowering.LoweringState,
23+
method: ir.Method,
24+
node: ast.Call,
25+
) -> lowering.Result:
26+
args, keywords = self.__lower_Call_args_kwargs(state, node)
27+
stmt = func.Invoke(args, callee=method, kwargs=keywords)
28+
stmt.result.type = method.return_type or ir.types.Any
29+
return lowering.Result(state.append_stmt(stmt))
30+
31+
def __lower_Call_args_kwargs(
32+
self,
33+
state: lowering.LoweringState,
34+
node: ast.Call,
35+
):
36+
args: list[ir.SSAValue] = []
37+
for arg in node.args:
38+
if isinstance(arg, ast.Starred): # TODO: support *args
39+
raise DialectLoweringError("starred arguments are not supported")
40+
else:
41+
args.append(state.visit(arg).expect_one())
42+
43+
keywords = []
44+
for kw in node.keywords:
45+
keywords.append(kw.arg)
46+
args.append(state.visit(kw.value).expect_one())
47+
48+
return tuple(args), tuple(keywords)

0 commit comments

Comments
 (0)