Skip to content

Commit 39d8f31

Browse files
committed
add vmath.add
1 parent 300e4f2 commit 39d8f31

File tree

4 files changed

+58
-0
lines changed

4 files changed

+58
-0
lines changed

src/kirin/dialects/vmath/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,13 @@
1414
ListLen = TypeVar("ListLen")
1515

1616

17+
@lowering.wraps(stmts.add)
18+
def add(
19+
lhs: ilist.IList[float, ListLen] | float,
20+
rhs: ilist.IList[float, ListLen] | float,
21+
) -> ilist.IList[float, ListLen]: ...
22+
23+
1724
@lowering.wraps(stmts.acos)
1825
def acos(x: ilist.IList[float, ListLen]) -> ilist.IList[float, ListLen]: ...
1926

src/kirin/dialects/vmath/interp.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,17 @@
1212
@dialect.register
1313
class MathMethodTable(MethodTable):
1414

15+
@impl(stmts.add)
16+
def add(self, interp, frame: Frame, stmt: stmts.add):
17+
lhs = frame.get(stmt.lhs)
18+
rhs = frame.get(stmt.rhs)
19+
if isinstance(lhs, ilist.IList):
20+
lhs = np.asarray(lhs)
21+
if isinstance(rhs, ilist.IList):
22+
rhs = np.asarray(rhs)
23+
arraysum = lhs + rhs
24+
return (ilist.IList(arraysum.tolist(), elem=types.Float),)
25+
1526
@impl(stmts.acos)
1627
def acos(self, interp, frame: Frame, stmt: stmts.acos):
1728
values = frame.get_values(stmt.args)

src/kirin/dialects/vmath/stmts.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,21 @@
77
ListLen = types.TypeVar("ListLen")
88

99

10+
@statement(dialect=dialect)
11+
class add(ir.Statement):
12+
"""Addition statement"""
13+
14+
name = "Add"
15+
traits = frozenset({ir.Pure(), lowering.FromPythonCall()})
16+
lhs: ir.SSAValue = info.argument(
17+
ilist.IListType[types.Float, ListLen] | types.Float
18+
)
19+
rhs: ir.SSAValue = info.argument(
20+
ilist.IListType[types.Float, ListLen] | types.Float
21+
)
22+
result: ir.ResultValue = info.result(types.Any)
23+
24+
1025
@statement(dialect=dialect)
1126
class acos(ir.Statement):
1227
"""acos statement, wrapping the math.acos function"""

test/dialects/vmath/test_basic.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,31 @@
66
from kirin.dialects import ilist, vmath
77

88

9+
@basic.union([vmath])
10+
def add_kernel(x, y):
11+
return vmath.add(x, y)
12+
13+
14+
def test_add_lists():
15+
a = ilist.IList([0.0, 1.0, 2.0], elem=types.Float)
16+
b = ilist.IList([3.0, 4.0, 5.0], elem=types.Float)
17+
truth = np.array([0.0, 1.0, 2.0]) + np.array([3.0, 4.0, 5.0])
18+
out = add_kernel(a, b)
19+
assert isinstance(out, ilist.IList)
20+
assert out.elem == types.Float
21+
assert np.allclose(out, truth)
22+
23+
24+
def test_add_scalar_list():
25+
a = 2.0
26+
b = ilist.IList([3.0, 4.0, 5.0], elem=types.Float)
27+
truth = 2.0 + np.array([3.0, 4.0, 5.0])
28+
out = add_kernel(a, b)
29+
assert isinstance(out, ilist.IList)
30+
assert out.elem == types.Float
31+
assert np.allclose(out, truth)
32+
33+
934
@basic.union([vmath])
1035
def acos_func(x):
1136
return vmath.acos(x)

0 commit comments

Comments
 (0)