Skip to content

Commit bf5052d

Browse files
authored
FromPythonWith trait (#142)
1 parent d9f03b0 commit bf5052d

File tree

6 files changed

+249
-54
lines changed

6 files changed

+249
-54
lines changed

src/kirin/ir/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,11 @@
3333
NoTerminator as NoTerminator,
3434
SSACFGRegion as SSACFGRegion,
3535
FromPythonCall as FromPythonCall,
36+
FromPythonWith as FromPythonWith,
3637
IsolatedFromAbove as IsolatedFromAbove,
3738
SymbolOpInterface as SymbolOpInterface,
3839
PythonLoweringTrait as PythonLoweringTrait,
3940
CallableStmtInterface as CallableStmtInterface,
41+
FromPythonWithSingleItem as FromPythonWithSingleItem,
4042
)
4143
from kirin.ir.dialect import Dialect as Dialect

src/kirin/ir/traits/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,7 @@
1818
)
1919
from .lowering.call import FromPythonCall as FromPythonCall
2020
from .region.ssacfg import SSACFGRegion as SSACFGRegion
21+
from .lowering.context import (
22+
FromPythonWith as FromPythonWith,
23+
FromPythonWithSingleItem as FromPythonWithSingleItem,
24+
)

src/kirin/ir/traits/lowering/call.py

Lines changed: 16 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@
22
from typing import TYPE_CHECKING, TypeVar
33
from dataclasses import dataclass
44

5-
from kirin.exceptions import DialectLoweringError
6-
75
from ..abc import PythonLoweringTrait
86

97
if TYPE_CHECKING:
@@ -16,60 +14,24 @@
1614

1715
@dataclass(frozen=True)
1816
class FromPythonCall(PythonLoweringTrait[StatementType, ast.Call]):
17+
"""Trait for customizing lowering of Python calls to a statement.
18+
19+
Declared in a statement definition to indicate that the statement can be
20+
constructed from a Python call (i.e., a function call `ast.Call` in the
21+
Python AST).
22+
23+
Subclassing this trait allows for customizing the lowering of Python calls
24+
to the statement. The `lower` method should be implemented to parse the
25+
arguments from the Python call and construct the statement instance.
26+
"""
1927

2028
def lower(
2129
self, stmt: type[StatementType], state: "LoweringState", node: ast.Call
2230
) -> "Result":
23-
from kirin.decl import fields
24-
from kirin.lowering import Result
25-
from kirin.dialects.py.data import PyAttr
26-
27-
fs = fields(stmt)
28-
stmt_std_arg_names = fs.std_args.keys()
29-
stmt_kw_args_name = fs.kw_args.keys()
30-
stmt_attr_prop_names = fs.attr_or_props
31-
stmt_required_names = fs.required_names
32-
stmt_group_arg_names = fs.group_arg_names
33-
args, kwargs = {}, {}
34-
for name, value in zip(stmt_std_arg_names, node.args):
35-
self._parse_arg(stmt_group_arg_names, state, args, name, value)
36-
for kw in node.keywords:
37-
if not isinstance(kw.arg, str):
38-
raise DialectLoweringError("Expected string for keyword argument name")
39-
40-
arg: str = kw.arg
41-
if arg in node.args:
42-
raise DialectLoweringError(
43-
f"Keyword argument {arg} is already present in positional arguments"
44-
)
45-
elif arg in stmt_std_arg_names or arg in stmt_kw_args_name:
46-
self._parse_arg(stmt_group_arg_names, state, kwargs, kw.arg, kw.value)
47-
elif arg in stmt_attr_prop_names:
48-
if not isinstance(kw.value, ast.Constant):
49-
raise DialectLoweringError(
50-
f"Expected constant for attribute or property {arg}"
51-
)
52-
kwargs[arg] = PyAttr(kw.value.value)
53-
else:
54-
raise DialectLoweringError(f"Unexpected keyword argument {arg}")
55-
56-
for name in stmt_required_names:
57-
if name not in args and name not in kwargs:
58-
raise DialectLoweringError(f"Missing required argument {name}")
59-
60-
return Result(state.append_stmt(stmt(*args.values(), **kwargs)))
31+
return state.default_Call_lower(stmt, node)
6132

62-
@staticmethod
63-
def _parse_arg(
64-
group_names: set[str],
65-
state: "LoweringState",
66-
target: dict,
67-
name: str,
68-
value: ast.AST,
69-
):
70-
if name in group_names:
71-
if not isinstance(value, ast.Tuple):
72-
raise DialectLoweringError(f"Expected tuple for group argument {name}")
73-
target[name] = tuple(state.visit(elem).expect_one() for elem in value.elts)
74-
else:
75-
target[name] = state.visit(value).expect_one()
33+
def verify(self, stmt: "Statement"):
34+
assert len(stmt.regions) == 0, "FromPythonCall statements cannot have regions"
35+
assert (
36+
len(stmt.successors) == 0
37+
), "FromPythonCall statements cannot have successors"
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
"""Traits for customizing lowering of Python `with` syntax to a statement.
2+
"""
3+
4+
import ast
5+
from typing import TYPE_CHECKING, TypeVar
6+
from dataclasses import dataclass
7+
8+
from kirin.exceptions import DialectLoweringError
9+
10+
from ..abc import PythonLoweringTrait
11+
12+
if TYPE_CHECKING:
13+
from kirin.ir import Statement
14+
from kirin.lowering import Result, LoweringState
15+
16+
StatementType = TypeVar("StatementType", bound="Statement")
17+
18+
19+
@dataclass(frozen=True)
20+
class FromPythonWith(PythonLoweringTrait[StatementType, ast.With]):
21+
"""Trait for customizing lowering of Python with statements to a statement.
22+
23+
Subclassing this trait allows for customizing the lowering of Python with
24+
statements to the statement. The `lower` method should be implemented to parse
25+
the arguments from the Python with statement and construct the statement instance.
26+
"""
27+
28+
pass
29+
30+
31+
@dataclass(frozen=True)
32+
class FromPythonWithSingleItem(FromPythonWith[StatementType]):
33+
"""Trait for customizing lowering of the following Python with syntax to a statement:
34+
35+
```python
36+
with <stmt>[ as <name>]:
37+
<body>
38+
```
39+
40+
where `<stmt>` is the statement being lowered, `<name>` is an optional name for the result
41+
of the statement, and `<body>` is the body of the with statement. The optional `as <name>`
42+
is not valid when the statement has no results.
43+
44+
This syntax is slightly different from the standard Python `with` statement in that
45+
`<name>` refers to the result of the statement, not the context manager. Thus typically
46+
one sould access `<name>` in `<body>` to use the result of the statement.
47+
48+
In some cases, however, `<name>` may be used as a reference of a special value `self` that
49+
is passed to the `<body>` of the statement. This is useful for statements that have a similar
50+
behavior to a closure.
51+
"""
52+
53+
def lower(
54+
self, stmt: type[StatementType], state: "LoweringState", node: ast.With
55+
) -> "Result":
56+
from kirin import lowering
57+
from kirin.decl import fields
58+
59+
fs = fields(stmt)
60+
if len(fs.regions) != 1:
61+
raise DialectLoweringError(
62+
"Expected exactly one region in statement declaration"
63+
)
64+
65+
if len(node.items) != 1:
66+
raise DialectLoweringError("Expected exactly one item in statement")
67+
68+
item, body = node.items[0], node.body
69+
if not isinstance(item.context_expr, ast.Call):
70+
raise DialectLoweringError(
71+
f"Expected context expression to be a call for with {stmt.name}"
72+
)
73+
74+
body_frame = lowering.Frame.from_stmts(body, state, parent=state.current_frame)
75+
state.push_frame(body_frame)
76+
state.exhaust()
77+
state.pop_frame()
78+
79+
args, kwargs = state.default_Call_inputs(stmt, item.context_expr)
80+
(region_name,) = fs.regions
81+
kwargs[region_name] = body_frame.current_region
82+
results = state.append_stmt(stmt(*args.values(), **kwargs)).results
83+
if len(results) == 0:
84+
return lowering.Result()
85+
elif len(results) > 1:
86+
raise DialectLoweringError(
87+
f"Expected exactly one result or no result from statement {stmt.name}"
88+
)
89+
90+
result = results[0]
91+
if item.optional_vars is not None and isinstance(item.optional_vars, ast.Name):
92+
result.name = item.optional_vars.id
93+
state.current_frame.defs[result.name] = result
94+
return lowering.Result(result)
95+
96+
def verify(self, stmt: "Statement"):
97+
assert (
98+
len(stmt.regions) == 1
99+
), "FromPythonWithSingleItem statements must have one region"
100+
assert (
101+
len(stmt.successors) == 0
102+
), "FromPythonWithSingleItem statements cannot have successors"
103+
assert (
104+
len(stmt.results) <= 1
105+
), "FromPythonWithSingleItem statements can have at most one result"

src/kirin/lowering/state.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,11 +112,38 @@ def visit(self, node: ast.AST) -> Result:
112112
# it will be called first before __dispatch_Call
113113
# because "Call" exists in self.registry
114114
return self.__dispatch_Call(node)
115+
elif isinstance(node, ast.With):
116+
return self.__dispatch_With(node)
115117
return super().visit(node)
116118

117119
def generic_visit(self, node: ast.AST):
118120
raise DialectLoweringError(f"unsupported ast node {type(node)}:")
119121

122+
def __dispatch_With(self, node: ast.With):
123+
if len(node.items) != 1:
124+
raise DialectLoweringError("expected exactly one item in with statement")
125+
126+
item = node.items[0]
127+
if not isinstance(item.context_expr, ast.Call):
128+
raise DialectLoweringError("expected context expression to be a call")
129+
130+
global_callee_result = self.get_global_nothrow(item.context_expr.func)
131+
if global_callee_result is None:
132+
raise DialectLoweringError("cannot find call func in with context")
133+
134+
global_callee = global_callee_result.unwrap()
135+
if not issubclass(global_callee, Statement):
136+
raise DialectLoweringError("expected callee to be a statement")
137+
138+
if (
139+
trait := global_callee.get_trait(traits.FromPythonWithSingleItem)
140+
) is not None:
141+
return trait.lower(global_callee, self, node)
142+
143+
raise DialectLoweringError(
144+
"unsupported callee, missing FromPythonWithSingleItem trait"
145+
)
146+
120147
def __dispatch_Call(self, node: ast.Call):
121148
# 1. try to lookup global statement object
122149
# 2. lookup local values
@@ -196,6 +223,63 @@ def __lower_Call_local(self, node: ast.Call) -> Result:
196223
return self.registry["Call_local"].lower_Call_local(self, callee, node)
197224
raise DialectLoweringError("`lower_Call_local` not implemented")
198225

226+
def default_Call_lower(self, stmt: type[Statement], node: ast.Call) -> Result:
227+
"""Default lowering for Python call to statement.
228+
229+
This method is intended to be used by traits like `FromPythonCall` to
230+
provide a default lowering for Python calls to statements.
231+
232+
Args:
233+
stmt(type[Statement]): Statement class to construct.
234+
node(ast.Call): Python call node to lower.
235+
236+
Returns:
237+
Result: Result of lowering the Python call to statement.
238+
"""
239+
args, kwargs = self.default_Call_inputs(stmt, node)
240+
return Result(self.append_stmt(stmt(*args.values(), **kwargs)))
241+
242+
def default_Call_inputs(
243+
self, stmt: type[Statement], node: ast.Call
244+
) -> tuple[dict[str, SSAValue | tuple[SSAValue, ...]], dict[str, Any]]:
245+
from kirin.decl import fields
246+
from kirin.dialects.py.data import PyAttr
247+
248+
fs = fields(stmt)
249+
stmt_std_arg_names = fs.std_args.keys()
250+
stmt_kw_args_name = fs.kw_args.keys()
251+
stmt_attr_prop_names = fs.attr_or_props
252+
stmt_required_names = fs.required_names
253+
stmt_group_arg_names = fs.group_arg_names
254+
args, kwargs = {}, {}
255+
for name, value in zip(stmt_std_arg_names, node.args):
256+
self._parse_arg(stmt_group_arg_names, args, name, value)
257+
for kw in node.keywords:
258+
if not isinstance(kw.arg, str):
259+
raise DialectLoweringError("Expected string for keyword argument name")
260+
261+
arg: str = kw.arg
262+
if arg in node.args:
263+
raise DialectLoweringError(
264+
f"Keyword argument {arg} is already present in positional arguments"
265+
)
266+
elif arg in stmt_std_arg_names or arg in stmt_kw_args_name:
267+
self._parse_arg(stmt_group_arg_names, kwargs, kw.arg, kw.value)
268+
elif arg in stmt_attr_prop_names:
269+
if not isinstance(kw.value, ast.Constant):
270+
raise DialectLoweringError(
271+
f"Expected constant for attribute or property {arg}"
272+
)
273+
kwargs[arg] = PyAttr(kw.value.value)
274+
else:
275+
raise DialectLoweringError(f"Unexpected keyword argument {arg}")
276+
277+
for name in stmt_required_names:
278+
if name not in args and name not in kwargs:
279+
raise DialectLoweringError(f"Missing required argument {name}")
280+
281+
return args, kwargs
282+
199283
def _parse_arg(
200284
self,
201285
group_names: set[str],

test/lowering/test_with.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
from kirin import ir
2+
from kirin.decl import info, statement
3+
from kirin.prelude import python_no_opt
4+
from kirin.dialects import cf, py, func
5+
from kirin.lowering import Lowering
6+
7+
dialect = ir.Dialect("test")
8+
9+
10+
@statement(dialect=dialect)
11+
class Adjoint(ir.Statement):
12+
traits = frozenset({ir.FromPythonWithSingleItem()})
13+
body: ir.Region = info.region()
14+
result: ir.ResultValue = info.result()
15+
16+
17+
def with_example(x):
18+
y = 1
19+
with Adjoint() as f: # type: ignore
20+
y = x + 1
21+
return y, f
22+
23+
24+
def test_with_lowering():
25+
lowering = Lowering(python_no_opt.union([cf, func, dialect]))
26+
code = lowering.run(with_example)
27+
code.print()
28+
assert isinstance(code, func.Function)
29+
stmts = code.body.blocks[0].stmts
30+
assert isinstance(stmts.at(0), py.Constant)
31+
adjoint = stmts.at(1)
32+
assert isinstance(adjoint, Adjoint)
33+
assert len(adjoint.body.blocks) == 1
34+
add = adjoint.body.blocks[0].stmts.at(1)
35+
assert isinstance(add, py.Add)
36+
assert isinstance(add.lhs, ir.BlockArgument)
37+
assert isinstance(add.rhs, ir.SSAValue)
38+
assert adjoint.result.name == "f"

0 commit comments

Comments
 (0)