diff --git a/src/kirin/dialects/vmath/__init__.py b/src/kirin/dialects/vmath/__init__.py index bf6f4152a..2bcc1d113 100644 --- a/src/kirin/dialects/vmath/__init__.py +++ b/src/kirin/dialects/vmath/__init__.py @@ -6,6 +6,7 @@ from . import stmts as stmts, interp as interp from ._dialect import dialect as dialect +from .rewrites import desugar as desugar pi = pymath.pi e = pymath.e diff --git a/src/kirin/dialects/vmath/passes.py b/src/kirin/dialects/vmath/passes.py deleted file mode 100644 index 6eb08c17b..000000000 --- a/src/kirin/dialects/vmath/passes.py +++ /dev/null @@ -1,16 +0,0 @@ -from kirin import ir -from kirin.rewrite import Walk -from kirin.passes.abc import Pass -from kirin.rewrite.abc import RewriteResult - -from .rewrites.desugar import DesugarBinOp - - -class VMathDesugar(Pass): - """This pass desugars the Python list dialect - to the immutable list dialect by rewriting all - constant `list` type into `IList` type. - """ - - def unsafe_run(self, mt: ir.Method) -> RewriteResult: - return Walk(DesugarBinOp()).rewrite(mt.code) diff --git a/src/kirin/dialects/vmath/rewrites/desugar.py b/src/kirin/dialects/vmath/rewrites/desugar.py index 1cee4c7ed..0c5ad9d46 100644 --- a/src/kirin/dialects/vmath/rewrites/desugar.py +++ b/src/kirin/dialects/vmath/rewrites/desugar.py @@ -6,6 +6,7 @@ from kirin.dialects.ilist import IListType from ..stmts import add as vadd, div as vdiv, sub as vsub, mult as vmult +from .._dialect import dialect class DesugarBinOp(RewriteRule): @@ -52,6 +53,7 @@ def replace_binop(self, node: ir.Statement) -> RewriteResult: return RewriteResult() +@dialect.post_inference class WalkDesugarBinop(RewriteRule): """ Walks DesugarBinop. Needed for correct behavior when diff --git a/test/dialects/vmath/test_desugar.py b/test/dialects/vmath/test_desugar.py index 674414e09..5b74526a1 100644 --- a/test/dialects/vmath/test_desugar.py +++ b/test/dialects/vmath/test_desugar.py @@ -4,7 +4,6 @@ from kirin.prelude import basic from kirin.dialects import vmath -from kirin.dialects.vmath.passes import VMathDesugar from kirin.dialects.ilist.runtime import IList @@ -25,7 +24,6 @@ def add_scalar_lhs(): def test_add_scalar_lhs(): # out = add_scalar_lhs() - VMathDesugar(add_scalar_lhs.dialects).unsafe_run(add_scalar_lhs) add_scalar_lhs.print() res = add_scalar_lhs() assert isinstance(res, IList) @@ -34,51 +32,46 @@ def test_add_scalar_lhs(): def test_typed_kernel_add(): - VMathDesugar(add_scalar_rhs_typed.dialects).unsafe_run(add_scalar_rhs_typed) add_scalar_rhs_typed.print() res = add_scalar_rhs_typed(IList([0, 1, 2]), 3.1) assert np.allclose(np.asarray(res), np.asarray([3.1, 4.1, 5.1])) -@basic.union([vmath]) +@basic.union([vmath])(typeinfer=True) def add_two_lists(): return add_kernel(x=[0, 1, 2], y=[3, 4, 5]) def test_add_lists(): - VMathDesugar(add_two_lists.dialects).unsafe_run(add_two_lists) res = add_two_lists() assert np.allclose(np.asarray(res), np.array([0, 1, 2, 3, 4, 5])) -@basic.union([vmath]) +@basic.union([vmath])(typeinfer=True) def sub_scalar_rhs_typed(x: IList[float, Any], y: float): return x - y def test_sub_scalar_typed(): - VMathDesugar(sub_scalar_rhs_typed.dialects).unsafe_run(sub_scalar_rhs_typed) res = sub_scalar_rhs_typed(IList([0, 1, 2]), 3.1) assert np.allclose(np.asarray(res), np.asarray([-3.1, -2.1, -1.1])) -@basic.union([vmath]) +@basic.union([vmath])(typeinfer=True) def mult_scalar_lhs_typed(x: float, y: IList[float, Any]): return x * y def test_mult_scalar_typed(): - VMathDesugar(mult_scalar_lhs_typed.dialects).unsafe_run(mult_scalar_lhs_typed) res = mult_scalar_lhs_typed(3, IList([0, 1, 2])) assert np.allclose(np.asarray(res), np.asarray([0, 3, 6])) -@basic.union([vmath]) +@basic.union([vmath])(typeinfer=True) def div_scalar_lhs_typed(x: float, y: IList[float, Any]): return x / y def test_div_scalar_typed(): - VMathDesugar(div_scalar_lhs_typed.dialects).unsafe_run(div_scalar_lhs_typed) res = div_scalar_lhs_typed(3, IList([1, 1.5, 2])) assert np.allclose(np.asarray(res), np.asarray([3, 2, 1.5]))