Skip to content

Commit 50a98bd

Browse files
authored
explicit Python ast.Call lowering trait (#139)
this PR removes the default behaviour of `@statement` which creates a lowering transform from `ast.Call` to the statement call. This is because `Statement` constructor is not necessarily compatible with `ast.Call` syntax, e.g when `Statement` contains a `Region` or `Block`, while most `Statement` is just a special `call`. The change here requires one explicitly declares `FromPythonCall` trait in the statement definition, which thus triggers the lowering transform to match a compatible syntax in Python AST. In the future, I think we can consider improving duplication in trait declaration by using inheritance, e.g how `BinOp` in `py.binop` dialect is currently implemented. This change is important for the following support of `with xxx as f` syntax where a given statement could have a similar roughly meaning of `with` syntax via a trait e.g `FromPythonWithSingleContext` indicates the statement has a compatible syntax with `with ... as f:` syntax where the statement's only region is treated as the with body.
1 parent 7522b59 commit 50a98bd

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

42 files changed

+382
-394
lines changed

src/kirin/decl/__init__.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
from kirin.decl.emit.dialect import EmitDialect
1616
from kirin.decl.emit.property import EmitProperty
1717
from kirin.decl.emit.typecheck import EmitTypeCheck
18-
from kirin.decl.emit.from_python_call import EmitFromPythonCall
1918

2019

2120
class StatementDecl(
@@ -29,7 +28,6 @@ class StatementDecl(
2928
EmitTraits,
3029
EmitVerify,
3130
EmitTypeCheck,
32-
EmitFromPythonCall,
3331
):
3432
pass
3533

@@ -63,5 +61,5 @@ def wrap(cls):
6361
return wrap(cls)
6462

6563

66-
def fields(cls) -> info.StatementFields:
64+
def fields(cls: type[Statement]) -> info.StatementFields:
6765
return getattr(cls, ScanFields._FIELDS)

src/kirin/decl/emit/from_python_call.py

Lines changed: 0 additions & 125 deletions
This file was deleted.

src/kirin/decl/emit/traits.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,3 @@ def emit_traits(self):
1111
if hasattr(base, "traits"):
1212
return
1313
set_new_attribute(self.cls, "traits", frozenset({}))
14-
return

src/kirin/dialects/fcf/stmts.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
@statement(dialect=dialect)
77
class Foldl(ir.Statement):
8+
traits = frozenset({ir.FromPythonCall()})
89
fn: ir.SSAValue = info.argument(ir.types.PyClass(ir.Method))
910
coll: ir.SSAValue = info.argument(ir.types.Any) # TODO: make this more precise
1011
init: ir.SSAValue = info.argument(ir.types.Any)
@@ -13,6 +14,7 @@ class Foldl(ir.Statement):
1314

1415
@statement(dialect=dialect)
1516
class Foldr(ir.Statement):
17+
traits = frozenset({ir.FromPythonCall()})
1618
fn: ir.SSAValue = info.argument(ir.types.PyClass(ir.Method))
1719
coll: ir.SSAValue = info.argument(ir.types.Any)
1820
init: ir.SSAValue = info.argument(ir.types.Any)
@@ -39,6 +41,7 @@ def main():
3941
```
4042
"""
4143

44+
traits = frozenset({ir.FromPythonCall()})
4245
fn: ir.SSAValue = info.argument(ir.types.PyClass(ir.Method))
4346
"""The kernel function to apply. The function should have signature `fn(x: int) -> Any`."""
4447
coll: ir.SSAValue = info.argument(ir.types.Any)
@@ -49,6 +52,7 @@ def main():
4952

5053
@statement(dialect=dialect)
5154
class Scan(ir.Statement):
55+
traits = frozenset({ir.FromPythonCall()})
5256
fn: ir.SSAValue = info.argument(ir.types.PyClass(ir.Method))
5357
init: ir.SSAValue = info.argument(ir.types.Any)
5458
coll: ir.SSAValue = info.argument(ir.types.List)

src/kirin/dialects/math/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
"math dialect, modeling functions in python's `math` stdlib" # This file is generated by gen.py
2-
from kirin.dialects.math import interp as interp
32
from kirin.dialects.math.stmts import (
43
cos as cos,
54
erf as erf,
@@ -38,4 +37,5 @@
3837
isfinite as isfinite,
3938
remainder as remainder,
4039
)
40+
from kirin.dialects.math.interp import MathMethodTable as MathMethodTable
4141
from kirin.dialects.math.dialect import dialect as dialect

src/kirin/dialects/math/_gen.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ class {name}(ir.Statement):
6565
\"\"\"{name} statement, wrapping the math.{name} function
6666
\"\"\"
6767
name = "{name}"
68-
traits = frozenset({{ir.Pure()}})
68+
traits = frozenset({{ir.Pure(), ir.FromPythonCall()}})
6969
{fields}
7070
result: ir.ResultValue = info.result(Float)
7171
"""
@@ -76,10 +76,9 @@ class {name}(ir.Statement):
7676
with open(os.path.join(os.path.dirname(__file__), "interp.py"), "w") as f:
7777
f.write("# This file is generated by gen.py\n")
7878
f.write("import math\n")
79-
f.write("from typing import Any\n")
8079
f.write("from kirin.dialects.math.dialect import dialect\n")
8180
f.write("from kirin.dialects.math import stmts\n")
82-
f.write("from kirin.interp import MethodTable, Frame, Result, impl\n")
81+
f.write("from kirin.interp import MethodTable, Frame, impl\n")
8382
f.write("\n")
8483

8584
implements = []
@@ -90,7 +89,7 @@ class {name}(ir.Statement):
9089
implements.append(
9190
f"""
9291
@impl(stmts.{name})
93-
def {name}(self, interp, frame: Frame, stmt: stmts.{name}) -> Result[Any]:
92+
def {name}(self, interp, frame: Frame, stmt: stmts.{name}):
9493
values = frame.get_values(stmt.args)
9594
return (math.{name}({fields}),)"""
9695
)
@@ -114,7 +113,9 @@ class MathMethodTable(MethodTable):
114113
for name, obj, sig in builtin_math_functions():
115114
f.write(f" {name} as {name},\n")
116115
f.write(")\n")
117-
f.write("from kirin.dialects.math.interp import Interpreter as Interpreter\n")
116+
f.write(
117+
"from kirin.dialects.math.interp import MathMethodTable as MathMethodTable\n"
118+
)
118119
f.write("\n")
119120

120121
for file in ["__init__.py", "interp.py", "stmts.py"]:

0 commit comments

Comments
 (0)