Skip to content

Commit acdd5d3

Browse files
committed
add sub, div, mul to vmath
1 parent 39d8f31 commit acdd5d3

File tree

4 files changed

+203
-3
lines changed

4 files changed

+203
-3
lines changed

src/kirin/dialects/vmath/__init__.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,13 @@ def cosh(x: ilist.IList[float, ListLen]) -> ilist.IList[float, ListLen]: ...
6969
def degrees(x: ilist.IList[float, ListLen]) -> ilist.IList[float, ListLen]: ...
7070

7171

72+
@lowering.wraps(stmts.div)
73+
def div(
74+
lhs: ilist.IList[float, ListLen] | float,
75+
rhs: ilist.IList[float, ListLen] | float,
76+
) -> ilist.IList[float, ListLen]: ...
77+
78+
7279
@lowering.wraps(stmts.erf)
7380
def erf(x: ilist.IList[float, ListLen]) -> ilist.IList[float, ListLen]: ...
7481

@@ -131,6 +138,13 @@ def log1p(x: ilist.IList[float, ListLen]) -> ilist.IList[float, ListLen]: ...
131138
def log2(x: ilist.IList[float, ListLen]) -> ilist.IList[float, ListLen]: ...
132139

133140

141+
@lowering.wraps(stmts.mul)
142+
def mul(
143+
lhs: ilist.IList[float, ListLen] | float,
144+
rhs: ilist.IList[float, ListLen] | float,
145+
) -> ilist.IList[float, ListLen]: ...
146+
147+
134148
@lowering.wraps(stmts.pow)
135149
def pow(x: ilist.IList[float, ListLen], y: float) -> ilist.IList[float, ListLen]: ...
136150

@@ -157,6 +171,13 @@ def sinh(x: ilist.IList[float, ListLen]) -> ilist.IList[float, ListLen]: ...
157171
def sqrt(x: ilist.IList[float, ListLen]) -> ilist.IList[float, ListLen]: ...
158172

159173

174+
@lowering.wraps(stmts.sub)
175+
def sub(
176+
lhs: ilist.IList[float, ListLen] | float,
177+
rhs: ilist.IList[float, ListLen] | float,
178+
) -> ilist.IList[float, ListLen]: ...
179+
180+
160181
@lowering.wraps(stmts.tan)
161182
def tan(x: ilist.IList[float, ListLen]) -> ilist.IList[float, ListLen]: ...
162183

src/kirin/dialects/vmath/interp.py

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@ def add(self, interp, frame: Frame, stmt: stmts.add):
2020
lhs = np.asarray(lhs)
2121
if isinstance(rhs, ilist.IList):
2222
rhs = np.asarray(rhs)
23-
arraysum = lhs + rhs
24-
return (ilist.IList(arraysum.tolist(), elem=types.Float),)
23+
result = lhs + rhs
24+
return (ilist.IList(result.tolist(), elem=types.Float),)
2525

2626
@impl(stmts.acos)
2727
def acos(self, interp, frame: Frame, stmt: stmts.acos):
@@ -100,6 +100,17 @@ def degrees(self, interp, frame: Frame, stmt: stmts.degrees):
100100
ilist.IList(np.degrees(np.asarray(values[0])).tolist(), elem=types.Float),
101101
)
102102

103+
@impl(stmts.div)
104+
def div(self, interp, frame: Frame, stmt: stmts.div):
105+
lhs = frame.get(stmt.lhs)
106+
rhs = frame.get(stmt.rhs)
107+
if isinstance(lhs, ilist.IList):
108+
lhs = np.asarray(lhs)
109+
if isinstance(rhs, ilist.IList):
110+
rhs = np.asarray(rhs)
111+
result = lhs / rhs
112+
return (ilist.IList(result.tolist(), elem=types.Float),)
113+
103114
@impl(stmts.erf)
104115
def erf(self, interp, frame: Frame, stmt: stmts.erf):
105116
values = frame.get_values(stmt.args)
@@ -202,6 +213,17 @@ def log2(self, interp, frame: Frame, stmt: stmts.log2):
202213
values = frame.get_values(stmt.args)
203214
return (ilist.IList(np.log2(np.asarray(values[0])).tolist(), elem=types.Float),)
204215

216+
@impl(stmts.mul)
217+
def mul(self, interp, frame: Frame, stmt: stmts.mul):
218+
lhs = frame.get(stmt.lhs)
219+
rhs = frame.get(stmt.rhs)
220+
if isinstance(lhs, ilist.IList):
221+
lhs = np.asarray(lhs)
222+
if isinstance(rhs, ilist.IList):
223+
rhs = np.asarray(rhs)
224+
result = lhs * rhs
225+
return (ilist.IList(result.tolist(), elem=types.Float),)
226+
205227
@impl(stmts.pow)
206228
def pow(self, interp, frame: Frame, stmt: stmts.pow):
207229
x = frame.get(stmt.x)
@@ -245,6 +267,17 @@ def sqrt(self, interp, frame: Frame, stmt: stmts.sqrt):
245267
values = frame.get_values(stmt.args)
246268
return (ilist.IList(np.sqrt(np.asarray(values[0])).tolist(), elem=types.Float),)
247269

270+
@impl(stmts.sub)
271+
def sub(self, interp, frame: Frame, stmt: stmts.sub):
272+
lhs = frame.get(stmt.lhs)
273+
rhs = frame.get(stmt.rhs)
274+
if isinstance(lhs, ilist.IList):
275+
lhs = np.asarray(lhs)
276+
if isinstance(rhs, ilist.IList):
277+
rhs = np.asarray(rhs)
278+
result = lhs - rhs
279+
return (ilist.IList(result.tolist(), elem=types.Float),)
280+
248281
@impl(stmts.tan)
249282
def tan(self, interp, frame: Frame, stmt: stmts.tan):
250283
values = frame.get_values(stmt.args)

src/kirin/dialects/vmath/stmts.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
class add(ir.Statement):
1212
"""Addition statement"""
1313

14-
name = "Add"
14+
name = "add"
1515
traits = frozenset({ir.Pure(), lowering.FromPythonCall()})
1616
lhs: ir.SSAValue = info.argument(
1717
ilist.IListType[types.Float, ListLen] | types.Float
@@ -134,6 +134,21 @@ class degrees(ir.Statement):
134134
result: ir.ResultValue = info.result(ilist.IListType[types.Float, ListLen])
135135

136136

137+
@statement(dialect=dialect)
138+
class div(ir.Statement):
139+
"""multiplication statement, scalar*list or list*list"""
140+
141+
name = "div"
142+
traits = frozenset({ir.Pure(), lowering.FromPythonCall()})
143+
lhs: ir.SSAValue = info.argument(
144+
ilist.IListType[types.Float, ListLen] | types.Float
145+
)
146+
rhs: ir.SSAValue = info.argument(
147+
ilist.IListType[types.Float, ListLen] | types.Float
148+
)
149+
result: ir.ResultValue = info.result(types.Any)
150+
151+
137152
@statement(dialect=dialect)
138153
class erf(ir.Statement):
139154
"""erf statement, wrapping the math.erf function"""
@@ -285,6 +300,21 @@ class log2(ir.Statement):
285300
result: ir.ResultValue = info.result(ilist.IListType[types.Float, ListLen])
286301

287302

303+
@statement(dialect=dialect)
304+
class mul(ir.Statement):
305+
"""multiplication statement, scalar*list or list*list"""
306+
307+
name = "mul"
308+
traits = frozenset({ir.Pure(), lowering.FromPythonCall()})
309+
lhs: ir.SSAValue = info.argument(
310+
ilist.IListType[types.Float, ListLen] | types.Float
311+
)
312+
rhs: ir.SSAValue = info.argument(
313+
ilist.IListType[types.Float, ListLen] | types.Float
314+
)
315+
result: ir.ResultValue = info.result(types.Any)
316+
317+
288318
@statement(dialect=dialect)
289319
class pow(ir.Statement):
290320
"""pow statement, wrapping the math.pow function"""
@@ -337,6 +367,21 @@ class sinh(ir.Statement):
337367
result: ir.ResultValue = info.result(ilist.IListType[types.Float, ListLen])
338368

339369

370+
@statement(dialect=dialect)
371+
class sub(ir.Statement):
372+
"""multiplication statement, scalar*list or list*list"""
373+
374+
name = "sub"
375+
traits = frozenset({ir.Pure(), lowering.FromPythonCall()})
376+
lhs: ir.SSAValue = info.argument(
377+
ilist.IListType[types.Float, ListLen] | types.Float
378+
)
379+
rhs: ir.SSAValue = info.argument(
380+
ilist.IListType[types.Float, ListLen] | types.Float
381+
)
382+
result: ir.ResultValue = info.result(types.Any)
383+
384+
340385
@statement(dialect=dialect)
341386
class sqrt(ir.Statement):
342387
"""sqrt statement, wrapping the math.sqrt function"""

test/dialects/vmath/test_basic.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,116 @@ def test_add_lists():
2121
assert np.allclose(out, truth)
2222

2323

24+
@basic.union([vmath])
25+
def sub_kernel(x, y):
26+
return vmath.sub(x, y)
27+
28+
29+
def test_sub_lists():
30+
a = ilist.IList([5.0, 7.0, 9.0], elem=types.Float)
31+
b = ilist.IList([3.0, 4.0, 5.0], elem=types.Float)
32+
truth = np.array([5.0, 7.0, 9.0]) - np.array([3.0, 4.0, 5.0])
33+
out = sub_kernel(a, b)
34+
assert isinstance(out, ilist.IList)
35+
assert out.elem == types.Float
36+
assert np.allclose(out, truth)
37+
38+
39+
def test_sub_scalar_list():
40+
a = 10.0
41+
b = ilist.IList([3.0, 4.0, 5.0], elem=types.Float)
42+
truth = 10.0 - np.array([3.0, 4.0, 5.0])
43+
out = sub_kernel(a, b)
44+
out2 = sub_kernel(b, a)
45+
46+
assert isinstance(out, ilist.IList)
47+
assert out.elem == types.Float
48+
assert np.allclose(out, truth)
49+
50+
truth2 = np.array([3.0, 4.0, 5.0]) - 10.0
51+
assert isinstance(out2, ilist.IList)
52+
assert out2.elem == types.Float
53+
assert np.allclose(out2, truth2)
54+
55+
2456
def test_add_scalar_list():
2557
a = 2.0
2658
b = ilist.IList([3.0, 4.0, 5.0], elem=types.Float)
2759
truth = 2.0 + np.array([3.0, 4.0, 5.0])
2860
out = add_kernel(a, b)
61+
out2 = add_kernel(b, a)
62+
2963
assert isinstance(out, ilist.IList)
3064
assert out.elem == types.Float
3165
assert np.allclose(out, truth)
3266

67+
assert isinstance(out2, ilist.IList)
68+
assert out2.elem == types.Float
69+
assert np.allclose(out2, truth)
70+
71+
72+
@basic.union([vmath])
73+
def mul_kernel(x, y):
74+
return vmath.mul(x, y)
75+
76+
77+
def test_mul_lists():
78+
a = ilist.IList([1.0, 2.0, 3.0], elem=types.Float)
79+
b = ilist.IList([4.0, 5.0, 6.0], elem=types.Float)
80+
truth = np.array([1.0, 2.0, 3.0]) * np.array([4.0, 5.0, 6.0])
81+
out = mul_kernel(a, b)
82+
assert isinstance(out, ilist.IList)
83+
assert out.elem == types.Float
84+
assert np.allclose(out, truth)
85+
86+
87+
def test_mul_scalar_list():
88+
a = 3.0
89+
b = ilist.IList([4.0, 5.0, 6.0], elem=types.Float)
90+
truth = 3.0 * np.array([4.0, 5.0, 6.0])
91+
out = mul_kernel(a, b)
92+
out2 = mul_kernel(b, a)
93+
94+
assert isinstance(out, ilist.IList)
95+
assert out.elem == types.Float
96+
assert np.allclose(out, truth)
97+
98+
assert isinstance(out2, ilist.IList)
99+
assert out2.elem == types.Float
100+
assert np.allclose(out2, truth)
101+
102+
103+
@basic.union([vmath])
104+
def div_kernel(x, y):
105+
return vmath.div(x, y)
106+
107+
108+
def test_div_lists():
109+
a = ilist.IList([8.0, 9.0, 10.0], elem=types.Float)
110+
b = ilist.IList([2.0, 3.0, 5.2], elem=types.Float)
111+
truth = np.array([8.0, 9.0, 10.0]) / np.array([2.0, 3.0, 5.2])
112+
out = div_kernel(a, b)
113+
assert isinstance(out, ilist.IList)
114+
assert out.elem == types.Float
115+
assert np.allclose(out, truth)
116+
117+
118+
def test_div_scalar_list():
119+
a = 12.0
120+
b = ilist.IList([2.0, 3.0, 4.0], elem=types.Float)
121+
truth = 12.0 / np.array([2.0, 3.0, 4.0])
122+
out = div_kernel(a, b)
123+
out2 = div_kernel(b, a)
124+
125+
assert isinstance(out, ilist.IList)
126+
assert out.elem == types.Float
127+
assert np.allclose(out, truth)
128+
129+
truth2 = np.array([2.0, 3.0, 4.0]) / 12.0
130+
assert isinstance(out2, ilist.IList)
131+
assert out2.elem == types.Float
132+
assert np.allclose(out2, truth2)
133+
33134

34135
@basic.union([vmath])
35136
def acos_func(x):

0 commit comments

Comments
 (0)