@@ -5,6 +5,13 @@ function mulexprcost(ex::Expr)
5
5
base = ex. head === :call ? 10 : 1
6
6
base + length (ex. args)
7
7
end
8
+ function mul_fast_expr (args)
9
+ b = Expr (:call , :mul_fast )
10
+ for i ∈ 2 : length (args)
11
+ push! (b. args, args[i])
12
+ end
13
+ b
14
+ end
8
15
function mulexpr (mulexargs)
9
16
a = (mulexargs[1 ]):: Union{Symbol,Expr,Number}
10
17
if length (mulexargs) == 2
@@ -25,17 +32,17 @@ function mulexpr(mulexargs)
25
32
return (c, Expr (:call , :mul_fast , a, b))
26
33
end
27
34
else
28
- return (a, Expr ( :call , :mul_fast , @view ( mulexargs[ 2 : end ]) . .. ) :: Expr )
35
+ return (a, mul_fast_expr ( mulexargs) )
29
36
end
30
37
a = (mulexargs[1 ]):: Union{Symbol,Expr,Number}
31
38
b = if length (mulexargs) == 2 # two arg mul
32
39
(mulexargs[2 ]):: Union{Symbol,Expr,Number}
33
40
else
34
- Expr ( :call , :mul_fast , @view ( mulexargs[ 2 : end ]) . .. ) :: Expr
41
+ mul_fast_expr ( mulexargs)
35
42
end
36
43
a, b
37
44
end
38
- function append_args_skip! (call, args, i)
45
+ function append_args_skip! (call, args, i, mod )
39
46
for j ∈ eachindex (args)
40
47
j == i && continue
41
48
push! (call. args, args[j])
@@ -44,18 +51,29 @@ function append_args_skip!(call, args, i)
44
51
end
45
52
46
53
fastfunc (f) = get (VectorizationBase. FASTDICT, f, f)
47
- function make_fast! (call:: Expr )
48
- call. args[1 ] = fastfunc (first (call. args))
49
- nothing
54
+ function muladd_arguments! (argv, mod, f = first (argv))
55
+ if f === :*
56
+ argv[1 ] = :mul_fast
57
+ else
58
+ argv[1 ] = fastfunc (f)
59
+ end
60
+ for i ∈ 2 : length (argv)
61
+ a = argv[i]
62
+ a isa Expr || continue
63
+ argv[i] = capture_muladd (a:: Expr , mod)
64
+ end
50
65
end
51
66
52
- function recursive_muladd_search! (call, argv, cnmul:: Bool = false , csub:: Bool = false )
53
- length (argv) < 3 && (make_fast! (call); return length (call. args) == 4 , cnmul, csub)
67
+ function recursive_muladd_search! (call, argv, mod, cnmul:: Bool = false , csub:: Bool = false )
68
+ if length (argv) < 3
69
+ muladd_arguments! (argv, mod)
70
+ return length (call. args) == 4 , cnmul, csub
71
+ end
54
72
fun = first (argv)
55
73
isadd = fun === :+ || fun === :add_fast || fun === :vadd || (fun == :(Base. FastMath. add_fast)):: Bool
56
74
issub = fun === :- || fun === :sub_fast || fun === :vsub || (fun == :(Base. FastMath. sub_fast)):: Bool
57
75
if ! (isadd | issub)
58
- argv[ 1 ] = fastfunc ( fun)
76
+ muladd_arguments! ( argv, mod, fun)
59
77
return length (call. args) == 4 , cnmul, csub
60
78
end
61
79
exargs = @view (argv[2 : end ])
@@ -72,29 +90,29 @@ function recursive_muladd_search!(call, argv, cnmul::Bool = false, csub::Bool =
72
90
if length (exargs) == 2
73
91
push! (call. args, exargs[3 - i])
74
92
else
75
- push! (call. args, append_args_skip! (Expr (:call , :add_fast ), exargs, i))
93
+ push! (call. args, append_args_skip! (Expr (:call , :add_fast ), exargs, i, mod ))
76
94
end
77
95
if issub
78
96
csub = i == 1
79
97
cnmul = ! csub
80
98
end
81
99
return true , cnmul, csub
82
100
elseif isadd
83
- found, cnmul, csub = recursive_muladd_search! (call, exa)
101
+ found, cnmul, csub = recursive_muladd_search! (call, exa, mod )
84
102
if found
85
103
if csub
86
104
call. args[4 ] = if length (exargs) == 2
87
105
Expr (:call , :sub_fast , exargs[3 - i], call. args[4 ])
88
106
else
89
- Expr (:call , :sub_fast , append_args_skip! (Expr (:call , :add_fast ), exargs, i), call. args[4 ])
107
+ Expr (:call , :sub_fast , append_args_skip! (Expr (:call , :add_fast ), exargs, i, mod ), call. args[4 ])
90
108
end
91
109
else
92
- call. args[4 ] = append_args_skip! (Expr (:call , :add_fast , call. args[4 ]), exargs, i)
110
+ call. args[4 ] = append_args_skip! (Expr (:call , :add_fast , call. args[4 ]), exargs, i, mod )
93
111
end
94
112
return true , cnmul, false
95
113
end
96
114
elseif issub
97
- found, cnmul, csub = recursive_muladd_search! (call, exa)
115
+ found, cnmul, csub = recursive_muladd_search! (call, exa, mod )
98
116
if found
99
117
if i == 1
100
118
if csub
@@ -119,10 +137,11 @@ function recursive_muladd_search!(call, argv, cnmul::Bool = false, csub::Bool =
119
137
length (call. args) == 4 , cnmul, csub
120
138
end
121
139
122
- function capture_muladd (ex:: Expr , mod)
140
+ function capture_a_muladd (ex:: Expr , mod)
123
141
call = Expr (:call , Symbol (" " ), Symbol (" " ), Symbol (" " ))
124
- found, nmul, sub = recursive_muladd_search! (call, ex. args)
125
- found || return ex
142
+ found, nmul, sub = recursive_muladd_search! (call, ex. args, mod)
143
+ found || return false , ex
144
+ # found || return ex
126
145
# a, b, c = call.args[2], call.args[3], call.args[4]
127
146
# call.args[2], call.args[3], call.args[4] = c, a, b
128
147
f = if nmul && sub
@@ -139,7 +158,13 @@ function capture_muladd(ex::Expr, mod)
139
158
else
140
159
call. args[1 ] = Expr (:(.), mod, QuoteNote (f))# _fast))
141
160
end
142
- call
161
+ true , call
162
+ end
163
+ function capture_muladd (ex:: Expr , mod)
164
+ while true
165
+ found, ex = capture_a_muladd (ex, mod)
166
+ found || return ex
167
+ end
143
168
end
144
169
145
170
contract_pass! (:: Any , :: Any ) = nothing
0 commit comments