Skip to content

Commit 3c63a48

Browse files
authored
fix wrong type and interp for pow and add missing stmts for vmath (#476)
1 parent 45f62c1 commit 3c63a48

File tree

4 files changed

+80
-7
lines changed

4 files changed

+80
-7
lines changed

src/kirin/dialects/vmath/__init__.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -125,9 +125,7 @@ def log2(x: ilist.IList[float, ListLen]) -> ilist.IList[float, ListLen]: ...
125125

126126

127127
@lowering.wraps(stmts.pow)
128-
def pow(
129-
x: ilist.IList[float, ListLen], y: ilist.IList[float, ListLen]
130-
) -> ilist.IList[float, ListLen]: ...
128+
def pow(x: ilist.IList[float, ListLen], y: float) -> ilist.IList[float, ListLen]: ...
131129

132130

133131
@lowering.wraps(stmts.radians)
@@ -162,3 +160,15 @@ def tanh(x: ilist.IList[float, ListLen]) -> ilist.IList[float, ListLen]: ...
162160

163161
@lowering.wraps(stmts.trunc)
164162
def trunc(x: ilist.IList[float, ListLen]) -> ilist.IList[float, ListLen]: ...
163+
164+
165+
@lowering.wraps(stmts.scale)
166+
def scale(
167+
value: float, x: ilist.IList[float, ListLen]
168+
) -> ilist.IList[float, ListLen]: ...
169+
170+
171+
@lowering.wraps(stmts.offset)
172+
def offset(
173+
value: float, x: ilist.IList[float, ListLen]
174+
) -> ilist.IList[float, ListLen]: ...

src/kirin/dialects/vmath/interp.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -187,10 +187,11 @@ def log2(self, interp, frame: Frame, stmt: stmts.log2):
187187

188188
@impl(stmts.pow)
189189
def pow(self, interp, frame: Frame, stmt: stmts.pow):
190-
values = frame.get_values(stmt.args)
190+
x = frame.get(stmt.x)
191+
y = frame.get(stmt.y)
191192
return (
192193
ilist.IList(
193-
np.pow(np.asarray(values[0]), np.asarray(values[1])).tolist(),
194+
np.pow(np.asarray(x), y).tolist(),
194195
elem=types.Float,
195196
),
196197
)
@@ -243,3 +244,15 @@ def trunc(self, interp, frame: Frame, stmt: stmts.trunc):
243244
return (
244245
ilist.IList(np.trunc(np.asarray(values[0])).tolist(), elem=types.Float),
245246
)
247+
248+
@impl(stmts.scale)
249+
def scale(self, interp, frame: Frame, stmt: stmts.scale):
250+
a = frame.get(stmt.value)
251+
x = frame.get(stmt.x)
252+
return (ilist.IList((np.asarray(x) * a).tolist(), elem=types.Float),)
253+
254+
@impl(stmts.offset)
255+
def offset(self, interp, frame: Frame, stmt: stmts.offset):
256+
a = frame.get(stmt.value)
257+
x = frame.get(stmt.x)
258+
return (ilist.IList((np.asarray(x) + a).tolist(), elem=types.Float),)

src/kirin/dialects/vmath/stmts.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -360,3 +360,25 @@ class trunc(ir.Statement):
360360
traits = frozenset({ir.Pure(), lowering.FromPythonCall()})
361361
x: ir.SSAValue = info.argument(ilist.IListType[types.Float, ListLen])
362362
result: ir.ResultValue = info.result(ilist.IListType[types.Float, ListLen])
363+
364+
365+
@statement(dialect=dialect)
366+
class scale(ir.Statement):
367+
"""scale with a scalar statement"""
368+
369+
name = "scale"
370+
traits = frozenset({ir.Pure(), lowering.FromPythonCall()})
371+
x: ir.SSAValue = info.argument(ilist.IListType[types.Float, ListLen])
372+
value: ir.SSAValue = info.argument(types.Float)
373+
result: ir.ResultValue = info.result(ilist.IListType[types.Float, ListLen])
374+
375+
376+
@statement(dialect=dialect)
377+
class offset(ir.Statement):
378+
"""offset with a scalar statement"""
379+
380+
name = "offset"
381+
traits = frozenset({ir.Pure(), lowering.FromPythonCall()})
382+
x: ir.SSAValue = info.argument(ilist.IListType[types.Float, ListLen])
383+
value: ir.SSAValue = info.argument(types.Float)
384+
result: ir.ResultValue = info.result(ilist.IListType[types.Float, ListLen])

test/dialects/vmath/test_basic.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -370,11 +370,11 @@ def pow_func(x, y):
370370
def test_pow():
371371
truth = np.pow(
372372
ilist.IList([0.42, 0.87, 0.32], elem=types.Float),
373-
ilist.IList([0.42, 0.87, 0.32], elem=types.Float),
373+
3.33,
374374
)
375375
out = pow_func(
376376
ilist.IList([0.42, 0.87, 0.32], elem=types.Float),
377-
ilist.IList([0.42, 0.87, 0.32], elem=types.Float),
377+
3.33,
378378
)
379379
assert isinstance(out, ilist.IList)
380380
assert out.elem == types.Float
@@ -489,3 +489,31 @@ def test_trunc():
489489
assert isinstance(out, ilist.IList)
490490
assert out.elem == types.Float
491491
assert np.allclose(out, truth)
492+
493+
494+
@basic.union([vmath])
495+
def scale_func(x, y):
496+
return vmath.scale(value=y, x=x)
497+
498+
499+
def test_scale():
500+
a = 3.3
501+
truth = np.array(ilist.IList([0.42, 0.87, 0.32], elem=types.Float)) * a
502+
out = scale_func(ilist.IList([0.42, 0.87, 0.32], elem=types.Float), a)
503+
assert isinstance(out, ilist.IList)
504+
assert out.elem == types.Float
505+
assert np.allclose(out, truth)
506+
507+
508+
@basic.union([vmath])
509+
def offset_func(x, y):
510+
return vmath.offset(value=y, x=x)
511+
512+
513+
def test_offset():
514+
a = 3.3
515+
truth = np.array(ilist.IList([0.42, 0.87, 0.32], elem=types.Float)) + a
516+
out = offset_func(ilist.IList([0.42, 0.87, 0.32], elem=types.Float), a)
517+
assert isinstance(out, ilist.IList)
518+
assert out.elem == types.Float
519+
assert np.allclose(out, truth)

0 commit comments

Comments
 (0)