Skip to content

Commit cf9824e

Browse files
committed
add more tests
1 parent 54500b2 commit cf9824e

File tree

1 file changed

+46
-6
lines changed

1 file changed

+46
-6
lines changed

test/dialects/vmath/test_desugar.py

Lines changed: 46 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,6 @@ def add_scalar_rhs_typed(x: IList[float, Any], y: float):
1818
return x + y
1919

2020

21-
@basic.union([vmath])
22-
def add_two_lists():
23-
return add_kernel(x=[0, 1, 2], y=[3, 4, 5])
24-
25-
2621
@basic.union([vmath])(aggressive=True, typeinfer=True)
2722
def add_scalar_lhs():
2823
return add_kernel(x=3.0, y=[3.0, 4, 5])
@@ -41,4 +36,49 @@ def test_add_scalar_lhs():
4136
def test_typed_kernel_add():
4237
VMathDesugar(add_scalar_rhs_typed.dialects).unsafe_run(add_scalar_rhs_typed)
4338
add_scalar_rhs_typed.print()
44-
print(add_scalar_rhs_typed(IList([0, 1, 2]), 3.1))
39+
res = add_scalar_rhs_typed(IList([0, 1, 2]), 3.1)
40+
assert np.allclose(np.asarray(res), np.asarray([3.1, 4.1, 5.1]))
41+
42+
43+
@basic.union([vmath])
44+
def add_two_lists():
45+
return add_kernel(x=[0, 1, 2], y=[3, 4, 5])
46+
47+
48+
def test_add_lists():
49+
VMathDesugar(add_two_lists.dialects).unsafe_run(add_two_lists)
50+
res = add_two_lists()
51+
assert np.allclose(np.asarray(res), np.array([0, 1, 2, 3, 4, 5]))
52+
53+
54+
@basic.union([vmath])
55+
def sub_scalar_rhs_typed(x: IList[float, Any], y: float):
56+
return x - y
57+
58+
59+
def test_sub_scalar_typed():
60+
VMathDesugar(sub_scalar_rhs_typed.dialects).unsafe_run(sub_scalar_rhs_typed)
61+
res = sub_scalar_rhs_typed(IList([0, 1, 2]), 3.1)
62+
assert np.allclose(np.asarray(res), np.asarray([-3.1, -2.1, -1.1]))
63+
64+
65+
@basic.union([vmath])
66+
def mult_scalar_lhs_typed(x: float, y: IList[float, Any]):
67+
return x * y
68+
69+
70+
def test_mult_scalar_typed():
71+
VMathDesugar(mult_scalar_lhs_typed.dialects).unsafe_run(mult_scalar_lhs_typed)
72+
res = mult_scalar_lhs_typed(3, IList([0, 1, 2]))
73+
assert np.allclose(np.asarray(res), np.asarray([0, 3, 6]))
74+
75+
76+
@basic.union([vmath])
77+
def div_scalar_lhs_typed(x: float, y: IList[float, Any]):
78+
return x / y
79+
80+
81+
def test_div_scalar_typed():
82+
VMathDesugar(div_scalar_lhs_typed.dialects).unsafe_run(div_scalar_lhs_typed)
83+
res = div_scalar_lhs_typed(3, IList([1, 1.5, 2]))
84+
assert np.allclose(np.asarray(res), np.asarray([3, 2, 1.5]))

0 commit comments

Comments
 (0)