@@ -18,11 +18,6 @@ def add_scalar_rhs_typed(x: IList[float, Any], y: float):
1818 return x + y
1919
2020
21- @basic .union ([vmath ])
22- def add_two_lists ():
23- return add_kernel (x = [0 , 1 , 2 ], y = [3 , 4 , 5 ])
24-
25-
2621@basic .union ([vmath ])(aggressive = True , typeinfer = True )
2722def add_scalar_lhs ():
2823 return add_kernel (x = 3.0 , y = [3.0 , 4 , 5 ])
@@ -41,4 +36,49 @@ def test_add_scalar_lhs():
4136def test_typed_kernel_add ():
4237 VMathDesugar (add_scalar_rhs_typed .dialects ).unsafe_run (add_scalar_rhs_typed )
4338 add_scalar_rhs_typed .print ()
44- print (add_scalar_rhs_typed (IList ([0 , 1 , 2 ]), 3.1 ))
39+ res = add_scalar_rhs_typed (IList ([0 , 1 , 2 ]), 3.1 )
40+ assert np .allclose (np .asarray (res ), np .asarray ([3.1 , 4.1 , 5.1 ]))
41+
42+
43+ @basic .union ([vmath ])
44+ def add_two_lists ():
45+ return add_kernel (x = [0 , 1 , 2 ], y = [3 , 4 , 5 ])
46+
47+
48+ def test_add_lists ():
49+ VMathDesugar (add_two_lists .dialects ).unsafe_run (add_two_lists )
50+ res = add_two_lists ()
51+ assert np .allclose (np .asarray (res ), np .array ([0 , 1 , 2 , 3 , 4 , 5 ]))
52+
53+
54+ @basic .union ([vmath ])
55+ def sub_scalar_rhs_typed (x : IList [float , Any ], y : float ):
56+ return x - y
57+
58+
59+ def test_sub_scalar_typed ():
60+ VMathDesugar (sub_scalar_rhs_typed .dialects ).unsafe_run (sub_scalar_rhs_typed )
61+ res = sub_scalar_rhs_typed (IList ([0 , 1 , 2 ]), 3.1 )
62+ assert np .allclose (np .asarray (res ), np .asarray ([- 3.1 , - 2.1 , - 1.1 ]))
63+
64+
65+ @basic .union ([vmath ])
66+ def mult_scalar_lhs_typed (x : float , y : IList [float , Any ]):
67+ return x * y
68+
69+
70+ def test_mult_scalar_typed ():
71+ VMathDesugar (mult_scalar_lhs_typed .dialects ).unsafe_run (mult_scalar_lhs_typed )
72+ res = mult_scalar_lhs_typed (3 , IList ([0 , 1 , 2 ]))
73+ assert np .allclose (np .asarray (res ), np .asarray ([0 , 3 , 6 ]))
74+
75+
76+ @basic .union ([vmath ])
77+ def div_scalar_lhs_typed (x : float , y : IList [float , Any ]):
78+ return x / y
79+
80+
81+ def test_div_scalar_typed ():
82+ VMathDesugar (div_scalar_lhs_typed .dialects ).unsafe_run (div_scalar_lhs_typed )
83+ res = div_scalar_lhs_typed (3 , IList ([1 , 1.5 , 2 ]))
84+ assert np .allclose (np .asarray (res ), np .asarray ([3 , 2 , 1.5 ]))
0 commit comments