diff --git a/src/kirin/dialects/vmath/__init__.py b/src/kirin/dialects/vmath/__init__.py index 91b3c2bc5..bf6f4152a 100644 --- a/src/kirin/dialects/vmath/__init__.py +++ b/src/kirin/dialects/vmath/__init__.py @@ -14,6 +14,13 @@ ListLen = TypeVar("ListLen") +@lowering.wraps(stmts.add) +def add( + lhs: ilist.IList[float, ListLen] | float, + rhs: ilist.IList[float, ListLen] | float, +) -> ilist.IList[float, ListLen]: ... + + @lowering.wraps(stmts.acos) def acos(x: ilist.IList[float, ListLen]) -> ilist.IList[float, ListLen]: ... @@ -62,6 +69,13 @@ def cosh(x: ilist.IList[float, ListLen]) -> ilist.IList[float, ListLen]: ... def degrees(x: ilist.IList[float, ListLen]) -> ilist.IList[float, ListLen]: ... +@lowering.wraps(stmts.div) +def div( + lhs: ilist.IList[float, ListLen] | float, + rhs: ilist.IList[float, ListLen] | float, +) -> ilist.IList[float, ListLen]: ... + + @lowering.wraps(stmts.erf) def erf(x: ilist.IList[float, ListLen]) -> ilist.IList[float, ListLen]: ... @@ -124,6 +138,13 @@ def log1p(x: ilist.IList[float, ListLen]) -> ilist.IList[float, ListLen]: ... def log2(x: ilist.IList[float, ListLen]) -> ilist.IList[float, ListLen]: ... +@lowering.wraps(stmts.mult) +def mult( + lhs: ilist.IList[float, ListLen] | float, + rhs: ilist.IList[float, ListLen] | float, +) -> ilist.IList[float, ListLen]: ... + + @lowering.wraps(stmts.pow) def pow(x: ilist.IList[float, ListLen], y: float) -> ilist.IList[float, ListLen]: ... @@ -150,6 +171,13 @@ def sinh(x: ilist.IList[float, ListLen]) -> ilist.IList[float, ListLen]: ... def sqrt(x: ilist.IList[float, ListLen]) -> ilist.IList[float, ListLen]: ... +@lowering.wraps(stmts.sub) +def sub( + lhs: ilist.IList[float, ListLen] | float, + rhs: ilist.IList[float, ListLen] | float, +) -> ilist.IList[float, ListLen]: ... + + @lowering.wraps(stmts.tan) def tan(x: ilist.IList[float, ListLen]) -> ilist.IList[float, ListLen]: ... diff --git a/src/kirin/dialects/vmath/interp.py b/src/kirin/dialects/vmath/interp.py index bc7ef6208..e0114164b 100644 --- a/src/kirin/dialects/vmath/interp.py +++ b/src/kirin/dialects/vmath/interp.py @@ -12,6 +12,17 @@ @dialect.register class MathMethodTable(MethodTable): + @impl(stmts.add) + def add(self, interp, frame: Frame, stmt: stmts.add): + lhs = frame.get(stmt.lhs) + rhs = frame.get(stmt.rhs) + if isinstance(lhs, ilist.IList): + lhs = np.asarray(lhs) + if isinstance(rhs, ilist.IList): + rhs = np.asarray(rhs) + result = lhs + rhs + return (ilist.IList(result.tolist(), elem=types.Float),) + @impl(stmts.acos) def acos(self, interp, frame: Frame, stmt: stmts.acos): values = frame.get_values(stmt.args) @@ -89,6 +100,17 @@ def degrees(self, interp, frame: Frame, stmt: stmts.degrees): ilist.IList(np.degrees(np.asarray(values[0])).tolist(), elem=types.Float), ) + @impl(stmts.div) + def div(self, interp, frame: Frame, stmt: stmts.div): + lhs = frame.get(stmt.lhs) + rhs = frame.get(stmt.rhs) + if isinstance(lhs, ilist.IList): + lhs = np.asarray(lhs) + if isinstance(rhs, ilist.IList): + rhs = np.asarray(rhs) + result = lhs / rhs + return (ilist.IList(result.tolist(), elem=types.Float),) + @impl(stmts.erf) def erf(self, interp, frame: Frame, stmt: stmts.erf): values = frame.get_values(stmt.args) @@ -191,6 +213,17 @@ def log2(self, interp, frame: Frame, stmt: stmts.log2): values = frame.get_values(stmt.args) return (ilist.IList(np.log2(np.asarray(values[0])).tolist(), elem=types.Float),) + @impl(stmts.mult) + def mult(self, interp, frame: Frame, stmt: stmts.mult): + lhs = frame.get(stmt.lhs) + rhs = frame.get(stmt.rhs) + if isinstance(lhs, ilist.IList): + lhs = np.asarray(lhs) + if isinstance(rhs, ilist.IList): + rhs = np.asarray(rhs) + result = lhs * rhs + return (ilist.IList(result.tolist(), elem=types.Float),) + @impl(stmts.pow) def pow(self, interp, frame: Frame, stmt: stmts.pow): x = frame.get(stmt.x) @@ -234,6 +267,17 @@ def sqrt(self, interp, frame: Frame, stmt: stmts.sqrt): values = frame.get_values(stmt.args) return (ilist.IList(np.sqrt(np.asarray(values[0])).tolist(), elem=types.Float),) + @impl(stmts.sub) + def sub(self, interp, frame: Frame, stmt: stmts.sub): + lhs = frame.get(stmt.lhs) + rhs = frame.get(stmt.rhs) + if isinstance(lhs, ilist.IList): + lhs = np.asarray(lhs) + if isinstance(rhs, ilist.IList): + rhs = np.asarray(rhs) + result = lhs - rhs + return (ilist.IList(result.tolist(), elem=types.Float),) + @impl(stmts.tan) def tan(self, interp, frame: Frame, stmt: stmts.tan): values = frame.get_values(stmt.args) diff --git a/src/kirin/dialects/vmath/passes.py b/src/kirin/dialects/vmath/passes.py new file mode 100644 index 000000000..6eb08c17b --- /dev/null +++ b/src/kirin/dialects/vmath/passes.py @@ -0,0 +1,16 @@ +from kirin import ir +from kirin.rewrite import Walk +from kirin.passes.abc import Pass +from kirin.rewrite.abc import RewriteResult + +from .rewrites.desugar import DesugarBinOp + + +class VMathDesugar(Pass): + """This pass desugars the Python list dialect + to the immutable list dialect by rewriting all + constant `list` type into `IList` type. + """ + + def unsafe_run(self, mt: ir.Method) -> RewriteResult: + return Walk(DesugarBinOp()).rewrite(mt.code) diff --git a/src/kirin/dialects/vmath/rewrites/__init__.py b/src/kirin/dialects/vmath/rewrites/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/kirin/dialects/vmath/rewrites/desugar.py b/src/kirin/dialects/vmath/rewrites/desugar.py new file mode 100644 index 000000000..1cee4c7ed --- /dev/null +++ b/src/kirin/dialects/vmath/rewrites/desugar.py @@ -0,0 +1,62 @@ +from kirin import ir, types +from kirin.rewrite import Walk +from kirin.dialects.py import Add, Div, Sub, Mult, BinOp +from kirin.rewrite.abc import RewriteRule, RewriteResult +from kirin.ir.nodes.base import IRNode +from kirin.dialects.ilist import IListType + +from ..stmts import add as vadd, div as vdiv, sub as vsub, mult as vmult + + +class DesugarBinOp(RewriteRule): + """ + Convert py.BinOp statements with one scalar arg and one IList arg + to the corresponding vmath binop. Currently supported binops are + add, mult, sub, and div. BinOps where both args are IList are not + supported, since `+` between two IList objects is taken to mean + concatenation. + """ + + def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: + match node: + case BinOp(): + if ( + node.lhs.type.is_subseteq(types.Number) + and node.rhs.type.is_subseteq(IListType) + ) or ( + node.lhs.type.is_subseteq(IListType) + and node.rhs.type.is_subseteq(types.Number) + ): + return self.replace_binop(node) + + case _: + return RewriteResult() + + return RewriteResult() + + def replace_binop(self, node: ir.Statement) -> RewriteResult: + match node: + case Add(): + node.replace_by(vadd(lhs=node.lhs, rhs=node.rhs)) + return RewriteResult(has_done_something=True) + case Sub(): + node.replace_by(vsub(lhs=node.lhs, rhs=node.rhs)) + return RewriteResult(has_done_something=True) + case Mult(): + node.replace_by(vmult(lhs=node.lhs, rhs=node.rhs)) + return RewriteResult(has_done_something=True) + case Div(): + node.replace_by(vdiv(lhs=node.lhs, rhs=node.rhs)) + return RewriteResult(has_done_something=True) + case _: + return RewriteResult() + + +class WalkDesugarBinop(RewriteRule): + """ + Walks DesugarBinop. Needed for correct behavior when + registering as a post-inference rewrite. + """ + + def rewrite(self, node: IRNode): + return Walk(DesugarBinOp()).rewrite(node) diff --git a/src/kirin/dialects/vmath/stmts.py b/src/kirin/dialects/vmath/stmts.py index 759753df4..cfb036395 100644 --- a/src/kirin/dialects/vmath/stmts.py +++ b/src/kirin/dialects/vmath/stmts.py @@ -7,6 +7,21 @@ ListLen = types.TypeVar("ListLen") +@statement(dialect=dialect) +class add(ir.Statement): + """Addition statement""" + + name = "add" + traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) + lhs: ir.SSAValue = info.argument( + ilist.IListType[types.Float, ListLen] | types.Float + ) + rhs: ir.SSAValue = info.argument( + ilist.IListType[types.Float, ListLen] | types.Float + ) + result: ir.ResultValue = info.result(types.Any) + + @statement(dialect=dialect) class acos(ir.Statement): """acos statement, wrapping the math.acos function""" @@ -119,6 +134,21 @@ class degrees(ir.Statement): result: ir.ResultValue = info.result(ilist.IListType[types.Float, ListLen]) +@statement(dialect=dialect) +class div(ir.Statement): + """multiplication statement, scalar*list or list*list""" + + name = "div" + traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) + lhs: ir.SSAValue = info.argument( + ilist.IListType[types.Float, ListLen] | types.Float + ) + rhs: ir.SSAValue = info.argument( + ilist.IListType[types.Float, ListLen] | types.Float + ) + result: ir.ResultValue = info.result(types.Any) + + @statement(dialect=dialect) class erf(ir.Statement): """erf statement, wrapping the math.erf function""" @@ -270,6 +300,21 @@ class log2(ir.Statement): result: ir.ResultValue = info.result(ilist.IListType[types.Float, ListLen]) +@statement(dialect=dialect) +class mult(ir.Statement): + """multiplication statement, scalar*list or list*list""" + + name = "mult" + traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) + lhs: ir.SSAValue = info.argument( + ilist.IListType[types.Float, ListLen] | types.Float + ) + rhs: ir.SSAValue = info.argument( + ilist.IListType[types.Float, ListLen] | types.Float + ) + result: ir.ResultValue = info.result(types.Any) + + @statement(dialect=dialect) class pow(ir.Statement): """pow statement, wrapping the math.pow function""" @@ -322,6 +367,21 @@ class sinh(ir.Statement): result: ir.ResultValue = info.result(ilist.IListType[types.Float, ListLen]) +@statement(dialect=dialect) +class sub(ir.Statement): + """multiplication statement, scalar*list or list*list""" + + name = "sub" + traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) + lhs: ir.SSAValue = info.argument( + ilist.IListType[types.Float, ListLen] | types.Float + ) + rhs: ir.SSAValue = info.argument( + ilist.IListType[types.Float, ListLen] | types.Float + ) + result: ir.ResultValue = info.result(types.Any) + + @statement(dialect=dialect) class sqrt(ir.Statement): """sqrt statement, wrapping the math.sqrt function""" diff --git a/test/dialects/vmath/test_basic.py b/test/dialects/vmath/test_basic.py index 0ba8b06b1..9df9306bd 100644 --- a/test/dialects/vmath/test_basic.py +++ b/test/dialects/vmath/test_basic.py @@ -6,6 +6,132 @@ from kirin.dialects import ilist, vmath +@basic.union([vmath]) +def add_kernel(x, y): + return vmath.add(x, y) + + +def test_add_lists(): + a = ilist.IList([0.0, 1.0, 2.0], elem=types.Float) + b = ilist.IList([3.0, 4.0, 5.0], elem=types.Float) + truth = np.array([0.0, 1.0, 2.0]) + np.array([3.0, 4.0, 5.0]) + out = add_kernel(a, b) + assert isinstance(out, ilist.IList) + assert out.elem == types.Float + assert np.allclose(out, truth) + + +@basic.union([vmath]) +def sub_kernel(x, y): + return vmath.sub(x, y) + + +def test_sub_lists(): + a = ilist.IList([5.0, 7.0, 9.0], elem=types.Float) + b = ilist.IList([3.0, 4.0, 5.0], elem=types.Float) + truth = np.array([5.0, 7.0, 9.0]) - np.array([3.0, 4.0, 5.0]) + out = sub_kernel(a, b) + assert isinstance(out, ilist.IList) + assert out.elem == types.Float + assert np.allclose(out, truth) + + +def test_sub_scalar_list(): + a = 10.0 + b = ilist.IList([3.0, 4.0, 5.0], elem=types.Float) + truth = 10.0 - np.array([3.0, 4.0, 5.0]) + out = sub_kernel(a, b) + out2 = sub_kernel(b, a) + + assert isinstance(out, ilist.IList) + assert out.elem == types.Float + assert np.allclose(out, truth) + + truth2 = np.array([3.0, 4.0, 5.0]) - 10.0 + assert isinstance(out2, ilist.IList) + assert out2.elem == types.Float + assert np.allclose(out2, truth2) + + +def test_add_scalar_list(): + a = 2.0 + b = ilist.IList([3.0, 4.0, 5.0], elem=types.Float) + truth = 2.0 + np.array([3.0, 4.0, 5.0]) + out = add_kernel(a, b) + out2 = add_kernel(b, a) + + assert isinstance(out, ilist.IList) + assert out.elem == types.Float + assert np.allclose(out, truth) + + assert isinstance(out2, ilist.IList) + assert out2.elem == types.Float + assert np.allclose(out2, truth) + + +@basic.union([vmath]) +def mult_kernel(x, y): + return vmath.mult(x, y) + + +def test_mult_lists(): + a = ilist.IList([1.0, 2.0, 3.0], elem=types.Float) + b = ilist.IList([4.0, 5.0, 6.0], elem=types.Float) + truth = np.array([1.0, 2.0, 3.0]) * np.array([4.0, 5.0, 6.0]) + out = mult_kernel(a, b) + assert isinstance(out, ilist.IList) + assert out.elem == types.Float + assert np.allclose(out, truth) + + +def test_mult_scalar_list(): + a = 3.0 + b = ilist.IList([4.0, 5.0, 6.0], elem=types.Float) + truth = 3.0 * np.array([4.0, 5.0, 6.0]) + out = mult_kernel(a, b) + out2 = mult_kernel(b, a) + + assert isinstance(out, ilist.IList) + assert out.elem == types.Float + assert np.allclose(out, truth) + + assert isinstance(out2, ilist.IList) + assert out2.elem == types.Float + assert np.allclose(out2, truth) + + +@basic.union([vmath]) +def div_kernel(x, y): + return vmath.div(x, y) + + +def test_div_lists(): + a = ilist.IList([8.0, 9.0, 10.0], elem=types.Float) + b = ilist.IList([2.0, 3.0, 5.2], elem=types.Float) + truth = np.array([8.0, 9.0, 10.0]) / np.array([2.0, 3.0, 5.2]) + out = div_kernel(a, b) + assert isinstance(out, ilist.IList) + assert out.elem == types.Float + assert np.allclose(out, truth) + + +def test_div_scalar_list(): + a = 12.0 + b = ilist.IList([2.0, 3.0, 4.0], elem=types.Float) + truth = 12.0 / np.array([2.0, 3.0, 4.0]) + out = div_kernel(a, b) + out2 = div_kernel(b, a) + + assert isinstance(out, ilist.IList) + assert out.elem == types.Float + assert np.allclose(out, truth) + + truth2 = np.array([2.0, 3.0, 4.0]) / 12.0 + assert isinstance(out2, ilist.IList) + assert out2.elem == types.Float + assert np.allclose(out2, truth2) + + @basic.union([vmath]) def acos_func(x): return vmath.acos(x) diff --git a/test/dialects/vmath/test_desugar.py b/test/dialects/vmath/test_desugar.py new file mode 100644 index 000000000..674414e09 --- /dev/null +++ b/test/dialects/vmath/test_desugar.py @@ -0,0 +1,84 @@ +from typing import Any + +import numpy as np + +from kirin.prelude import basic +from kirin.dialects import vmath +from kirin.dialects.vmath.passes import VMathDesugar +from kirin.dialects.ilist.runtime import IList + + +@basic.union([vmath]) +def add_kernel(x, y): + return x + y + + +@basic.union([vmath])(typeinfer=True) +def add_scalar_rhs_typed(x: IList[float, Any], y: float): + return x + y + + +@basic.union([vmath])(aggressive=True, typeinfer=True) +def add_scalar_lhs(): + return add_kernel(x=3.0, y=[3.0, 4, 5]) + + +def test_add_scalar_lhs(): + # out = add_scalar_lhs() + VMathDesugar(add_scalar_lhs.dialects).unsafe_run(add_scalar_lhs) + add_scalar_lhs.print() + res = add_scalar_lhs() + assert isinstance(res, IList) + assert res.type.vars[0].typ is float + assert np.allclose(np.asarray(res), np.array([6, 7, 8])) + + +def test_typed_kernel_add(): + VMathDesugar(add_scalar_rhs_typed.dialects).unsafe_run(add_scalar_rhs_typed) + add_scalar_rhs_typed.print() + res = add_scalar_rhs_typed(IList([0, 1, 2]), 3.1) + assert np.allclose(np.asarray(res), np.asarray([3.1, 4.1, 5.1])) + + +@basic.union([vmath]) +def add_two_lists(): + return add_kernel(x=[0, 1, 2], y=[3, 4, 5]) + + +def test_add_lists(): + VMathDesugar(add_two_lists.dialects).unsafe_run(add_two_lists) + res = add_two_lists() + assert np.allclose(np.asarray(res), np.array([0, 1, 2, 3, 4, 5])) + + +@basic.union([vmath]) +def sub_scalar_rhs_typed(x: IList[float, Any], y: float): + return x - y + + +def test_sub_scalar_typed(): + VMathDesugar(sub_scalar_rhs_typed.dialects).unsafe_run(sub_scalar_rhs_typed) + res = sub_scalar_rhs_typed(IList([0, 1, 2]), 3.1) + assert np.allclose(np.asarray(res), np.asarray([-3.1, -2.1, -1.1])) + + +@basic.union([vmath]) +def mult_scalar_lhs_typed(x: float, y: IList[float, Any]): + return x * y + + +def test_mult_scalar_typed(): + VMathDesugar(mult_scalar_lhs_typed.dialects).unsafe_run(mult_scalar_lhs_typed) + res = mult_scalar_lhs_typed(3, IList([0, 1, 2])) + assert np.allclose(np.asarray(res), np.asarray([0, 3, 6])) + + +@basic.union([vmath]) +def div_scalar_lhs_typed(x: float, y: IList[float, Any]): + return x / y + + +def test_div_scalar_typed(): + VMathDesugar(div_scalar_lhs_typed.dialects).unsafe_run(div_scalar_lhs_typed) + res = div_scalar_lhs_typed(3, IList([1, 1.5, 2])) + assert np.allclose(np.asarray(res), np.asarray([3, 2, 1.5]))