Skip to content

Commit 2b7a2ff

Browse files
authored
split py dialect (#138)
this PR splits the python dialect into many smaller dialects allow better reusibility in other DSLs. Most importantly we split an immutable subset of list semantics from Python which is mostly used in various dialects in QuEra.
1 parent f719541 commit 2b7a2ff

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

87 files changed

+1890
-1465
lines changed

src/kirin/decl/emit/from_python_call.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,11 @@ class EmitFromPythonCall(BaseModifier):
2828
_AST_CONSTANT = "_kirin_ast_Constant"
2929

3030
def __init__(self, cls: type, **kwargs: Unpack[StatementOptions]) -> None:
31-
from kirin.dialects.py import data
31+
from kirin.dialects.py.data import PyAttr
3232

3333
super().__init__(cls, **kwargs)
3434
self.globals[self._KIRIN_RESULT] = Result
35-
self.globals[self._KIRIN_PYATTR] = data.PyAttr
35+
self.globals[self._KIRIN_PYATTR] = PyAttr
3636
self.globals[self._KIRIN_ERROR] = DialectLoweringError
3737
self.globals[self._AST_TUPLE] = ast.Tuple
3838
self.globals[self._AST_CONSTANT] = ast.Constant

src/kirin/decl/emit/property.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ def _emit_result_property(self, index: int, f: info.ResultField):
111111
return getter, setter
112112

113113
def _emit_attribute_property(self, f: info.AttributeField):
114-
from kirin.dialects.py import data
114+
from kirin.dialects.py.data import PyAttr
115115

116116
storage = "properties" if f.property else "attributes"
117117
attr = f"{self._self_name}.{storage}['{f.name}']"
@@ -131,7 +131,7 @@ def _emit_attribute_property(self, f: info.AttributeField):
131131
f"raise AttributeError('attribute property {f.name} is read-only')"
132132
],
133133
globals=self.globals,
134-
locals={"_value_hint": data.PyAttr if f.pytype else f.annotation},
134+
locals={"_value_hint": PyAttr if f.pytype else f.annotation},
135135
return_type=None,
136136
)
137137
else:

src/kirin/dialects/cf/lower.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from kirin.lowering import Frame, Result, FromPythonAST, LoweringState
55
from kirin.exceptions import DialectLoweringError
66
from kirin.dialects.cf import stmts as cf
7-
from kirin.dialects.py import stmts
87
from kirin.dialects.cf.dialect import dialect
98

109

@@ -19,12 +18,14 @@ def lower_Pass(self, state: LoweringState, node: ast.Pass) -> Result:
1918
return Result()
2019

2120
def lower_Assert(self, state: LoweringState, node: ast.Assert) -> Result:
21+
from kirin.dialects.py.constant import Constant
22+
2223
cond = state.visit(node.test).expect_one()
2324
if node.msg:
2425
message = state.visit(node.msg).expect_one()
2526
state.append_stmt(cf.Assert(condition=cond, message=message))
2627
else:
27-
message_stmt = state.append_stmt(stmts.Constant(""))
28+
message_stmt = state.append_stmt(Constant(""))
2829
state.append_stmt(cf.Assert(condition=cond, message=message_stmt.result))
2930
return Result()
3031

src/kirin/dialects/py/__init__.py

Lines changed: 39 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,39 @@
1-
from . import data as data, stmts as stmts
2-
3-
# from kirin.dialects.pytypes import types as types
4-
# from kirin.dialects.pytypes.data import PyAttr as PyAttr
5-
# from kirin.dialects.pytypes.types.base import (
6-
# PyBottomType as PyBottomType,
7-
# PyClass as PyClass,
8-
# PyGeneric as PyGeneric,
9-
# PyTypeVar as PyTypeVar,
10-
# PyUnion as PyUnion,
11-
# PyVararg as PyVararg,
12-
# )
13-
# from kirin.dialects.pytypes.types.builtin import (
14-
# Bool as Bool,
15-
# Dict as Dict,
16-
# Float as Float,
17-
# FunctionType as FunctionType,
18-
# Int as Int,
19-
# List as List,
20-
# NoneType as NoneType,
21-
# String as String,
22-
# Tuple as Tuple,
23-
# TypeofFunctionType as TypeofFunctionType,
24-
# )
1+
from . import (
2+
cmp as cmp,
3+
len as len,
4+
attr as attr,
5+
data as data,
6+
binop as binop,
7+
ilist as ilist,
8+
range as range,
9+
slice as slice,
10+
tuple as tuple,
11+
unary as unary,
12+
append as append,
13+
assign as assign,
14+
boolop as boolop,
15+
builtin as builtin,
16+
constant as constant,
17+
indexing as indexing,
18+
)
19+
from .len import Len as Len
20+
from .attr import GetAttr as GetAttr
21+
from .data import PyAttr as PyAttr
22+
from .range import Range as Range
23+
from .slice import Slice as Slice
24+
from .append import Append as Append
25+
from .assign import Alias as Alias, SetItem as SetItem
26+
from .boolop import Or as Or, And as And
27+
from .builtin import Abs as Abs, Sum as Sum
28+
from .constant import Constant as Constant
29+
from .indexing import GetItem as GetItem, PyGetItemLike as PyGetItemLike
30+
from .cmp.stmts import * # noqa: F403
31+
from .binop.stmts import * # noqa: F403
32+
from .ilist.stmts import (
33+
Map as Map,
34+
Scan as Scan,
35+
FoldL as FoldL,
36+
FoldR as FoldR,
37+
ForEach as ForEach,
38+
)
39+
from .unary.stmts import * # noqa: F403

src/kirin/dialects/py/append.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
from typing import Any
2+
3+
from kirin import ir
4+
from kirin.decl import info, statement
5+
from kirin.interp import Frame, Interpreter, MethodTable, impl
6+
7+
dialect = ir.Dialect("py.list")
8+
9+
ElemT = ir.types.TypeVar("ElemT")
10+
11+
12+
@statement(dialect=dialect)
13+
class Append(ir.Statement):
14+
lst: ir.SSAValue = info.argument(ir.types.List[ElemT])
15+
value: ir.SSAValue = info.argument(ir.types.Any)
16+
17+
18+
@dialect.register
19+
class MutableListMethod(MethodTable):
20+
21+
@impl(Append)
22+
def append(self, interp: Interpreter, frame: Frame[Any], stmt: Append):
23+
lst: list = frame.get(stmt.lst)
24+
value = frame.get(stmt.value)
25+
lst.append(value)
26+
return (lst,)

src/kirin/dialects/py/assign.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
import ast
2+
3+
from kirin import ir, interp, lowering, exceptions
4+
from kirin.decl import info, statement
5+
from kirin.print import Printer
6+
from kirin.dialects.py import data
7+
8+
dialect = ir.Dialect("py.assign")
9+
10+
T = ir.types.TypeVar("T")
11+
12+
13+
@statement(dialect=dialect)
14+
class Alias(ir.Statement):
15+
name = "alias"
16+
traits = frozenset({ir.Pure()})
17+
value: ir.SSAValue = info.argument(T)
18+
target: data.PyAttr[str] = info.attribute(property=True)
19+
result: ir.ResultValue = info.result(T)
20+
21+
def print_impl(self, printer: Printer) -> None:
22+
printer.print_name(self)
23+
printer.plain_print(" ")
24+
with printer.rich(style=printer.color.symbol):
25+
printer.plain_print(self.target.data)
26+
27+
with printer.rich(style=printer.color.keyword):
28+
printer.plain_print(" = ")
29+
30+
printer.print(self.value)
31+
32+
33+
@statement(dialect=dialect)
34+
class SetItem(ir.Statement):
35+
name = "setitem"
36+
obj: ir.SSAValue = info.argument(print=False)
37+
value: ir.SSAValue = info.argument(print=False)
38+
index: ir.SSAValue = info.argument(print=False)
39+
40+
41+
@dialect.register
42+
class Concrete(interp.MethodTable):
43+
44+
@interp.impl(Alias)
45+
def alias(self, interp, frame: interp.Frame, stmt: Alias):
46+
return (frame.get(stmt.value),)
47+
48+
@interp.impl(SetItem)
49+
def setindex(self, interp, frame: interp.Frame, stmt: SetItem):
50+
frame.get(stmt.obj)[frame.get(stmt.index)] = frame.get(stmt.value)
51+
return (None,)
52+
53+
54+
@dialect.register
55+
class Lowering(lowering.FromPythonAST):
56+
57+
def lower_Assign(
58+
self, state: lowering.LoweringState, node: ast.Assign
59+
) -> lowering.Result:
60+
results: lowering.Result = state.visit(node.value)
61+
assert len(node.targets) == len(
62+
results
63+
), "number of targets and results do not match"
64+
65+
current_frame = state.current_frame
66+
match node:
67+
case ast.Assign(
68+
targets=[ast.Name(lhs_name, ast.Store())], value=ast.Name(_, ast.Load())
69+
):
70+
stmt = Alias(
71+
value=results[0], target=data.PyAttr(lhs_name)
72+
) # NOTE: this is guaranteed to be one result
73+
stmt.result.name = lhs_name
74+
current_frame.defs[lhs_name] = state.append_stmt(stmt).result
75+
case _:
76+
for target, value in zip(node.targets, results.values):
77+
match target:
78+
# NOTE: if the name exists new ssa value will be
79+
# used in the future to shadow the old one
80+
case ast.Name(name, ast.Store()):
81+
value.name = name
82+
current_frame.defs[name] = value
83+
case ast.Subscript(obj, slice):
84+
obj = state.visit(obj).expect_one()
85+
slice = state.visit(slice).expect_one()
86+
stmt = SetItem(obj=obj, index=slice, value=value)
87+
state.append_stmt(stmt)
88+
case _:
89+
raise exceptions.DialectLoweringError(
90+
f"unsupported target {target}"
91+
)
92+
return lowering.Result() # python assign does not have value

src/kirin/dialects/py/attr.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
import ast
2+
3+
from kirin import ir, interp, lowering, exceptions
4+
from kirin.decl import info, statement
5+
6+
dialect = ir.Dialect("py.attr")
7+
8+
9+
@statement(dialect=dialect)
10+
class GetAttr(ir.Statement):
11+
name = "getattr"
12+
obj: ir.SSAValue = info.argument(print=False)
13+
attrname: str = info.attribute(property=True)
14+
result: ir.ResultValue = info.result()
15+
16+
17+
@dialect.register
18+
class Concrete(interp.MethodTable):
19+
20+
@interp.impl(GetAttr)
21+
def getattr(self, interp: interp.Interpreter, frame: interp.Frame, stmt: GetAttr):
22+
return getattr(frame.get(stmt.obj), stmt.attrname)
23+
24+
25+
@dialect.register
26+
class Lowering(lowering.FromPythonAST):
27+
28+
def lower_Attribute(
29+
self, state: lowering.LoweringState, node: ast.Attribute
30+
) -> lowering.Result:
31+
if not isinstance(node.ctx, ast.Load):
32+
raise exceptions.DialectLoweringError(
33+
f"unsupported attribute context {node.ctx}"
34+
)
35+
value = state.visit(node.value).expect_one()
36+
stmt = GetAttr(obj=value, attrname=node.attr)
37+
state.append_stmt(stmt)
38+
return lowering.Result(stmt)
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
from . import (
2+
julia as julia,
3+
interp as interp,
4+
lowering as lowering,
5+
typeinfer as typeinfer,
6+
)
7+
from .stmts import * # noqa: F403
8+
from ._dialect import dialect as dialect
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from kirin import ir
2+
3+
dialect = ir.Dialect("py.binop")
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
from kirin import interp
2+
3+
from . import stmts
4+
from ._dialect import dialect
5+
6+
7+
@dialect.register
8+
class PyMethodTable(interp.MethodTable):
9+
10+
@interp.impl(stmts.Add)
11+
def add(self, interp, frame: interp.Frame, stmt: stmts.Add):
12+
return (frame.get(stmt.lhs) + frame.get(stmt.rhs),)
13+
14+
@interp.impl(stmts.Sub)
15+
def sub(self, interp, frame: interp.Frame, stmt: stmts.Sub):
16+
return (frame.get(stmt.lhs) - frame.get(stmt.rhs),)
17+
18+
@interp.impl(stmts.Mult)
19+
def mult(self, interp, frame: interp.Frame, stmt: stmts.Mult):
20+
return (frame.get(stmt.lhs) * frame.get(stmt.rhs),)
21+
22+
@interp.impl(stmts.Div)
23+
def div(self, interp, frame: interp.Frame, stmt: stmts.Div):
24+
return (frame.get(stmt.lhs) / frame.get(stmt.rhs),)
25+
26+
@interp.impl(stmts.Mod)
27+
def mod(self, interp, frame: interp.Frame, stmt: stmts.Mod):
28+
return (frame.get(stmt.lhs) % frame.get(stmt.rhs),)
29+
30+
@interp.impl(stmts.BitAnd)
31+
def bit_and(self, interp, frame: interp.Frame, stmt: stmts.BitAnd):
32+
return (frame.get(stmt.lhs) & frame.get(stmt.rhs),)
33+
34+
@interp.impl(stmts.BitOr)
35+
def bit_or(self, interp, frame: interp.Frame, stmt: stmts.BitOr):
36+
return (frame.get(stmt.lhs) | frame.get(stmt.rhs),)
37+
38+
@interp.impl(stmts.BitXor)
39+
def bit_xor(self, interp, frame: interp.Frame, stmt: stmts.BitXor):
40+
return (frame.get(stmt.lhs) ^ frame.get(stmt.rhs),)
41+
42+
@interp.impl(stmts.LShift)
43+
def lshift(self, interp, frame: interp.Frame, stmt: stmts.LShift):
44+
return (frame.get(stmt.lhs) << frame.get(stmt.rhs),)
45+
46+
@interp.impl(stmts.RShift)
47+
def rshift(self, interp, frame: interp.Frame, stmt: stmts.RShift):
48+
return (frame.get(stmt.lhs) >> frame.get(stmt.rhs),)
49+
50+
@interp.impl(stmts.FloorDiv)
51+
def floor_div(self, interp, frame: interp.Frame, stmt: stmts.FloorDiv):
52+
return (frame.get(stmt.lhs) // frame.get(stmt.rhs),)
53+
54+
@interp.impl(stmts.Pow)
55+
def pow(self, interp, frame: interp.Frame, stmt: stmts.Pow):
56+
return (frame.get(stmt.lhs) ** frame.get(stmt.rhs),)
57+
58+
@interp.impl(stmts.MatMult)
59+
def mat_mult(self, interp, frame: interp.Frame, stmt: stmts.MatMult):
60+
return (frame.get(stmt.lhs) @ frame.get(stmt.rhs),)

0 commit comments

Comments
 (0)