Skip to content
28 changes: 28 additions & 0 deletions src/kirin/dialects/vmath/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,13 @@
ListLen = TypeVar("ListLen")


@lowering.wraps(stmts.add)
def add(
lhs: ilist.IList[float, ListLen] | float,
rhs: ilist.IList[float, ListLen] | float,
) -> ilist.IList[float, ListLen]: ...


@lowering.wraps(stmts.acos)
def acos(x: ilist.IList[float, ListLen]) -> ilist.IList[float, ListLen]: ...

Expand Down Expand Up @@ -62,6 +69,13 @@ def cosh(x: ilist.IList[float, ListLen]) -> ilist.IList[float, ListLen]: ...
def degrees(x: ilist.IList[float, ListLen]) -> ilist.IList[float, ListLen]: ...


@lowering.wraps(stmts.div)
def div(
lhs: ilist.IList[float, ListLen] | float,
rhs: ilist.IList[float, ListLen] | float,
) -> ilist.IList[float, ListLen]: ...


@lowering.wraps(stmts.erf)
def erf(x: ilist.IList[float, ListLen]) -> ilist.IList[float, ListLen]: ...

Expand Down Expand Up @@ -124,6 +138,13 @@ def log1p(x: ilist.IList[float, ListLen]) -> ilist.IList[float, ListLen]: ...
def log2(x: ilist.IList[float, ListLen]) -> ilist.IList[float, ListLen]: ...


@lowering.wraps(stmts.mult)
def mult(
lhs: ilist.IList[float, ListLen] | float,
rhs: ilist.IList[float, ListLen] | float,
) -> ilist.IList[float, ListLen]: ...


@lowering.wraps(stmts.pow)
def pow(x: ilist.IList[float, ListLen], y: float) -> ilist.IList[float, ListLen]: ...

Expand All @@ -150,6 +171,13 @@ def sinh(x: ilist.IList[float, ListLen]) -> ilist.IList[float, ListLen]: ...
def sqrt(x: ilist.IList[float, ListLen]) -> ilist.IList[float, ListLen]: ...


@lowering.wraps(stmts.sub)
def sub(
lhs: ilist.IList[float, ListLen] | float,
rhs: ilist.IList[float, ListLen] | float,
) -> ilist.IList[float, ListLen]: ...


@lowering.wraps(stmts.tan)
def tan(x: ilist.IList[float, ListLen]) -> ilist.IList[float, ListLen]: ...

Expand Down
44 changes: 44 additions & 0 deletions src/kirin/dialects/vmath/interp.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,17 @@
@dialect.register
class MathMethodTable(MethodTable):

@impl(stmts.add)
def add(self, interp, frame: Frame, stmt: stmts.add):
lhs = frame.get(stmt.lhs)
rhs = frame.get(stmt.rhs)
if isinstance(lhs, ilist.IList):
lhs = np.asarray(lhs)
if isinstance(rhs, ilist.IList):
rhs = np.asarray(rhs)
result = lhs + rhs
return (ilist.IList(result.tolist(), elem=types.Float),)

@impl(stmts.acos)
def acos(self, interp, frame: Frame, stmt: stmts.acos):
values = frame.get_values(stmt.args)
Expand Down Expand Up @@ -89,6 +100,17 @@ def degrees(self, interp, frame: Frame, stmt: stmts.degrees):
ilist.IList(np.degrees(np.asarray(values[0])).tolist(), elem=types.Float),
)

@impl(stmts.div)
def div(self, interp, frame: Frame, stmt: stmts.div):
lhs = frame.get(stmt.lhs)
rhs = frame.get(stmt.rhs)
if isinstance(lhs, ilist.IList):
lhs = np.asarray(lhs)
if isinstance(rhs, ilist.IList):
rhs = np.asarray(rhs)
result = lhs / rhs
return (ilist.IList(result.tolist(), elem=types.Float),)

@impl(stmts.erf)
def erf(self, interp, frame: Frame, stmt: stmts.erf):
values = frame.get_values(stmt.args)
Expand Down Expand Up @@ -191,6 +213,17 @@ def log2(self, interp, frame: Frame, stmt: stmts.log2):
values = frame.get_values(stmt.args)
return (ilist.IList(np.log2(np.asarray(values[0])).tolist(), elem=types.Float),)

@impl(stmts.mult)
def mult(self, interp, frame: Frame, stmt: stmts.mult):
lhs = frame.get(stmt.lhs)
rhs = frame.get(stmt.rhs)
if isinstance(lhs, ilist.IList):
lhs = np.asarray(lhs)
if isinstance(rhs, ilist.IList):
rhs = np.asarray(rhs)
result = lhs * rhs
return (ilist.IList(result.tolist(), elem=types.Float),)

@impl(stmts.pow)
def pow(self, interp, frame: Frame, stmt: stmts.pow):
x = frame.get(stmt.x)
Expand Down Expand Up @@ -234,6 +267,17 @@ def sqrt(self, interp, frame: Frame, stmt: stmts.sqrt):
values = frame.get_values(stmt.args)
return (ilist.IList(np.sqrt(np.asarray(values[0])).tolist(), elem=types.Float),)

@impl(stmts.sub)
def sub(self, interp, frame: Frame, stmt: stmts.sub):
lhs = frame.get(stmt.lhs)
rhs = frame.get(stmt.rhs)
if isinstance(lhs, ilist.IList):
lhs = np.asarray(lhs)
if isinstance(rhs, ilist.IList):
rhs = np.asarray(rhs)
result = lhs - rhs
return (ilist.IList(result.tolist(), elem=types.Float),)

@impl(stmts.tan)
def tan(self, interp, frame: Frame, stmt: stmts.tan):
values = frame.get_values(stmt.args)
Expand Down
16 changes: 16 additions & 0 deletions src/kirin/dialects/vmath/passes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from kirin import ir
from kirin.rewrite import Walk
from kirin.passes.abc import Pass
from kirin.rewrite.abc import RewriteResult

from .rewrites.desugar import DesugarBinOp


class VMathDesugar(Pass):
"""This pass desugars the Python list dialect
to the immutable list dialect by rewriting all
constant `list` type into `IList` type.
"""

def unsafe_run(self, mt: ir.Method) -> RewriteResult:
return Walk(DesugarBinOp()).rewrite(mt.code)
Empty file.
62 changes: 62 additions & 0 deletions src/kirin/dialects/vmath/rewrites/desugar.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
from kirin import ir, types
from kirin.rewrite import Walk
from kirin.dialects.py import Add, Div, Sub, Mult, BinOp
from kirin.rewrite.abc import RewriteRule, RewriteResult
from kirin.ir.nodes.base import IRNode
from kirin.dialects.ilist import IListType

from ..stmts import add as vadd, div as vdiv, sub as vsub, mult as vmult


class DesugarBinOp(RewriteRule):
"""
Convert py.BinOp statements with one scalar arg and one IList arg
to the corresponding vmath binop. Currently supported binops are
add, mult, sub, and div. BinOps where both args are IList are not
supported, since `+` between two IList objects is taken to mean
concatenation.
"""

def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
match node:
case BinOp():
if (
node.lhs.type.is_subseteq(types.Number)
and node.rhs.type.is_subseteq(IListType)
) or (
node.lhs.type.is_subseteq(IListType)
and node.rhs.type.is_subseteq(types.Number)
):
return self.replace_binop(node)

case _:
return RewriteResult()

return RewriteResult()

def replace_binop(self, node: ir.Statement) -> RewriteResult:
match node:
case Add():
node.replace_by(vadd(lhs=node.lhs, rhs=node.rhs))
return RewriteResult(has_done_something=True)
case Sub():
node.replace_by(vsub(lhs=node.lhs, rhs=node.rhs))
return RewriteResult(has_done_something=True)
case Mult():
node.replace_by(vmult(lhs=node.lhs, rhs=node.rhs))
return RewriteResult(has_done_something=True)
case Div():
node.replace_by(vdiv(lhs=node.lhs, rhs=node.rhs))
return RewriteResult(has_done_something=True)
case _:
return RewriteResult()


class WalkDesugarBinop(RewriteRule):
"""
Walks DesugarBinop. Needed for correct behavior when
registering as a post-inference rewrite.
"""

def rewrite(self, node: IRNode):
return Walk(DesugarBinOp()).rewrite(node)
60 changes: 60 additions & 0 deletions src/kirin/dialects/vmath/stmts.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,21 @@
ListLen = types.TypeVar("ListLen")


@statement(dialect=dialect)
class add(ir.Statement):
"""Addition statement"""

name = "add"
traits = frozenset({ir.Pure(), lowering.FromPythonCall()})
lhs: ir.SSAValue = info.argument(
ilist.IListType[types.Float, ListLen] | types.Float
)
rhs: ir.SSAValue = info.argument(
ilist.IListType[types.Float, ListLen] | types.Float
)
result: ir.ResultValue = info.result(types.Any)


@statement(dialect=dialect)
class acos(ir.Statement):
"""acos statement, wrapping the math.acos function"""
Expand Down Expand Up @@ -119,6 +134,21 @@ class degrees(ir.Statement):
result: ir.ResultValue = info.result(ilist.IListType[types.Float, ListLen])


@statement(dialect=dialect)
class div(ir.Statement):
"""multiplication statement, scalar*list or list*list"""

name = "div"
traits = frozenset({ir.Pure(), lowering.FromPythonCall()})
lhs: ir.SSAValue = info.argument(
ilist.IListType[types.Float, ListLen] | types.Float
)
rhs: ir.SSAValue = info.argument(
ilist.IListType[types.Float, ListLen] | types.Float
)
result: ir.ResultValue = info.result(types.Any)


@statement(dialect=dialect)
class erf(ir.Statement):
"""erf statement, wrapping the math.erf function"""
Expand Down Expand Up @@ -270,6 +300,21 @@ class log2(ir.Statement):
result: ir.ResultValue = info.result(ilist.IListType[types.Float, ListLen])


@statement(dialect=dialect)
class mult(ir.Statement):
"""multiplication statement, scalar*list or list*list"""

name = "mult"
traits = frozenset({ir.Pure(), lowering.FromPythonCall()})
lhs: ir.SSAValue = info.argument(
ilist.IListType[types.Float, ListLen] | types.Float
)
rhs: ir.SSAValue = info.argument(
ilist.IListType[types.Float, ListLen] | types.Float
)
result: ir.ResultValue = info.result(types.Any)


@statement(dialect=dialect)
class pow(ir.Statement):
"""pow statement, wrapping the math.pow function"""
Expand Down Expand Up @@ -322,6 +367,21 @@ class sinh(ir.Statement):
result: ir.ResultValue = info.result(ilist.IListType[types.Float, ListLen])


@statement(dialect=dialect)
class sub(ir.Statement):
"""multiplication statement, scalar*list or list*list"""

name = "sub"
traits = frozenset({ir.Pure(), lowering.FromPythonCall()})
lhs: ir.SSAValue = info.argument(
ilist.IListType[types.Float, ListLen] | types.Float
)
rhs: ir.SSAValue = info.argument(
ilist.IListType[types.Float, ListLen] | types.Float
)
result: ir.ResultValue = info.result(types.Any)


@statement(dialect=dialect)
class sqrt(ir.Statement):
"""sqrt statement, wrapping the math.sqrt function"""
Expand Down
Loading