@@ -21,6 +21,7 @@ def main():
2121 return q
2222
2323 helper .print ()
24+ main .print ()
2425
2526 assert isinstance (helper .code , func .Function )
2627
@@ -37,3 +38,82 @@ def main():
3738 count_mults_in_main += isinstance (stmt , squin .op .stmts .Mult )
3839
3940 assert count_mults_in_main == 1
41+
42+
43+ def test_scale_rewrite ():
44+
45+ @squin .kernel
46+ def simple_rmul ():
47+ x = squin .op .x ()
48+ y = 2 * x
49+ return y
50+
51+ simple_rmul .print ()
52+
53+ assert isinstance (simple_rmul .code , func .Function )
54+
55+ simple_rmul_stmts = list (simple_rmul .code .body .stmts ())
56+ assert any (
57+ map (lambda stmt : isinstance (stmt , squin .op .stmts .Scale ), simple_rmul_stmts )
58+ )
59+ assert not any (
60+ map (lambda stmt : isinstance (stmt , squin .op .stmts .Mult ), simple_rmul_stmts )
61+ )
62+ assert not any (map (lambda stmt : isinstance (stmt , py .Mult ), simple_rmul_stmts ))
63+
64+ @squin .kernel
65+ def simple_lmul ():
66+ x = squin .op .x ()
67+ y = x * 2
68+ return y
69+
70+ simple_lmul .print ()
71+
72+ assert isinstance (simple_lmul .code , func .Function )
73+
74+ simple_lmul_stmts = list (simple_lmul .code .body .stmts ())
75+ assert any (
76+ map (lambda stmt : isinstance (stmt , squin .op .stmts .Scale ), simple_lmul_stmts )
77+ )
78+ assert not any (
79+ map (lambda stmt : isinstance (stmt , squin .op .stmts .Mult ), simple_lmul_stmts )
80+ )
81+ assert not any (map (lambda stmt : isinstance (stmt , py .Mult ), simple_lmul_stmts ))
82+
83+ @squin .kernel
84+ def scale_mult ():
85+ x = squin .op .x ()
86+ y = squin .op .y ()
87+ return 2 * (x * y )
88+
89+ assert isinstance (scale_mult .code , func .Function )
90+
91+ scale_mult_stmts = list (scale_mult .code .body .stmts ())
92+ assert (
93+ sum (map (lambda stmt : isinstance (stmt , squin .op .stmts .Scale ), scale_mult_stmts ))
94+ == 1
95+ )
96+ assert (
97+ sum (map (lambda stmt : isinstance (stmt , squin .op .stmts .Mult ), scale_mult_stmts ))
98+ == 1
99+ )
100+
101+ @squin .kernel
102+ def scale_mult2 ():
103+ x = squin .op .x ()
104+ y = squin .op .y ()
105+ return 2 * x * y
106+
107+ scale_mult2 .print ()
108+
109+ assert isinstance (scale_mult2 .code , func .Function )
110+
111+ scale_mult2_stmts = list (scale_mult2 .code .body .stmts ())
112+ assert (
113+ sum (map (lambda stmt : isinstance (stmt , squin .op .stmts .Scale ), scale_mult2_stmts ))
114+ == 1
115+ )
116+ assert (
117+ sum (map (lambda stmt : isinstance (stmt , squin .op .stmts .Mult ), scale_mult2_stmts ))
118+ == 1
119+ )
0 commit comments