Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/kirin/dialects/vmath/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from . import stmts as stmts, interp as interp
from ._dialect import dialect as dialect
from .rewrites import desugar as desugar

pi = pymath.pi
e = pymath.e
Expand Down
16 changes: 0 additions & 16 deletions src/kirin/dialects/vmath/passes.py

This file was deleted.

2 changes: 2 additions & 0 deletions src/kirin/dialects/vmath/rewrites/desugar.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from kirin.dialects.ilist import IListType

from ..stmts import add as vadd, div as vdiv, sub as vsub, mult as vmult
from .._dialect import dialect


class DesugarBinOp(RewriteRule):
Expand Down Expand Up @@ -52,6 +53,7 @@ def replace_binop(self, node: ir.Statement) -> RewriteResult:
return RewriteResult()


@dialect.post_inference
class WalkDesugarBinop(RewriteRule):
Comment on lines +56 to 57
Copy link

Copilot AI Nov 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The @dialect.post_inference decorator on WalkDesugarBinop will not be executed because the rewrites module is not imported in src/kirin/dialects/vmath/__init__.py.

For the decorator to register the rewrite rule with the dialect, the module containing the decorated class must be imported when the vmath dialect is loaded. This is the pattern used in other dialects like ilist, which imports rewrite in its __init__.py.

To fix this, add from . import rewrites as rewrites to src/kirin/dialects/vmath/__init__.py, similar to how ilist imports its rewrite module. Without this import, the tests will likely fail because the desugaring transformation won't be applied automatically.

Copilot uses AI. Check for mistakes.
"""
Walks DesugarBinop. Needed for correct behavior when
Expand Down
15 changes: 4 additions & 11 deletions test/dialects/vmath/test_desugar.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

from kirin.prelude import basic
from kirin.dialects import vmath
from kirin.dialects.vmath.passes import VMathDesugar
from kirin.dialects.ilist.runtime import IList


Expand All @@ -25,7 +24,6 @@ def add_scalar_lhs():

def test_add_scalar_lhs():
# out = add_scalar_lhs()
VMathDesugar(add_scalar_lhs.dialects).unsafe_run(add_scalar_lhs)
add_scalar_lhs.print()
res = add_scalar_lhs()
assert isinstance(res, IList)
Expand All @@ -34,51 +32,46 @@ def test_add_scalar_lhs():


def test_typed_kernel_add():
VMathDesugar(add_scalar_rhs_typed.dialects).unsafe_run(add_scalar_rhs_typed)
add_scalar_rhs_typed.print()
res = add_scalar_rhs_typed(IList([0, 1, 2]), 3.1)
assert np.allclose(np.asarray(res), np.asarray([3.1, 4.1, 5.1]))


@basic.union([vmath])
@basic.union([vmath])(typeinfer=True)
def add_two_lists():
return add_kernel(x=[0, 1, 2], y=[3, 4, 5])


def test_add_lists():
VMathDesugar(add_two_lists.dialects).unsafe_run(add_two_lists)
res = add_two_lists()
assert np.allclose(np.asarray(res), np.array([0, 1, 2, 3, 4, 5]))


@basic.union([vmath])
@basic.union([vmath])(typeinfer=True)
def sub_scalar_rhs_typed(x: IList[float, Any], y: float):
return x - y


def test_sub_scalar_typed():
VMathDesugar(sub_scalar_rhs_typed.dialects).unsafe_run(sub_scalar_rhs_typed)
res = sub_scalar_rhs_typed(IList([0, 1, 2]), 3.1)
assert np.allclose(np.asarray(res), np.asarray([-3.1, -2.1, -1.1]))


@basic.union([vmath])
@basic.union([vmath])(typeinfer=True)
def mult_scalar_lhs_typed(x: float, y: IList[float, Any]):
return x * y


def test_mult_scalar_typed():
VMathDesugar(mult_scalar_lhs_typed.dialects).unsafe_run(mult_scalar_lhs_typed)
res = mult_scalar_lhs_typed(3, IList([0, 1, 2]))
assert np.allclose(np.asarray(res), np.asarray([0, 3, 6]))


@basic.union([vmath])
@basic.union([vmath])(typeinfer=True)
def div_scalar_lhs_typed(x: float, y: IList[float, Any]):
return x / y


def test_div_scalar_typed():
VMathDesugar(div_scalar_lhs_typed.dialects).unsafe_run(div_scalar_lhs_typed)
res = div_scalar_lhs_typed(3, IList([1, 1.5, 2]))
assert np.allclose(np.asarray(res), np.asarray([3, 2, 1.5]))