@@ -76,80 +76,85 @@ function muladd_arguments!(argv, mod, f = first(argv))
76
76
end
77
77
78
78
function recursive_muladd_search! (call, argv, mod, cnmul:: Bool = false , csub:: Bool = false )
79
- if length (argv) < 3
80
- muladd_arguments! (argv, mod)
81
- return length (call. args) == 4 , cnmul, csub
82
- end
83
- fun = first (argv)
84
- isadd = fun === :+ || fun === :add_fast || fun === :vadd || (fun == :(Base. FastMath. add_fast)):: Bool
85
- issub = fun === :- || fun === :sub_fast || fun === :vsub || (fun == :(Base. FastMath. sub_fast)):: Bool
86
- if isadd
87
- argv[1 ] = :add_fast
88
- elseif issub
89
- argv[1 ] = :sub_fast
90
- else
91
- muladd_arguments! (argv, mod, fun)
92
- return length (call. args) == 4 , cnmul, csub
79
+ if length (argv) < 3
80
+ muladd_arguments! (argv, mod)
81
+ return length (call. args) == 4 , cnmul, csub
82
+ end
83
+ fun = first (argv)
84
+ isadd = fun === :+ || fun === :add_fast || fun === :vadd || (fun == :(Base. FastMath. add_fast)):: Bool
85
+ issub = fun === :- || fun === :sub_fast || fun === :vsub || (fun == :(Base. FastMath. sub_fast)):: Bool
86
+ if isadd
87
+ argv[1 ] = :add_fast
88
+ elseif issub
89
+ argv[1 ] = :sub_fast
90
+ else
91
+ muladd_arguments! (argv, mod, fun)
92
+ return length (call. args) == 4 , cnmul, csub
93
+ end
94
+ exargs = @view (argv[2 : end ])
95
+ for i ∈ eachindex (exargs)
96
+ if exargs[i] === :Inf
97
+ exargs[i] === Inf
93
98
end
94
- exargs = @view (argv[2 : end ])
95
- issub && @assert length (exargs) == 2
96
- for (i,ex) ∈ enumerate (exargs)
97
- if ex isa Expr && ex. head === :call
98
- exa = ex. args
99
- f = first (exa)
100
- exav = @view (exa[2 : end ])
101
- if f === :* || f === :mul_fast || f === :vmul || (f == :(Base. FastMath. mul_fast)):: Bool
102
- a, b = mulexpr (exav)
103
- call. args[2 ] = a
104
- call. args[3 ] = b
105
- if length (exargs) == 2
106
- push! (call. args, exargs[3 - i])
107
- else
108
- push! (call. args, append_args_skip! (Expr (:call , :add_fast ), exargs, i, mod))
109
- end
110
- if issub
111
- csub = i == 1
112
- cnmul = ! csub
113
- end
114
- return true , cnmul, csub
115
- elseif isadd
116
- found, cnmul, csub = recursive_muladd_search! (call, exa, mod)
117
- if found
118
- if csub
119
- call. args[4 ] = if length (exargs) == 2
120
- Expr (:call , :sub_fast , exargs[3 - i], call. args[4 ])
121
- else
122
- Expr (:call , :sub_fast , append_args_skip! (Expr (:call , :add_fast ), exargs, i, mod), call. args[4 ])
123
- end
124
- else
125
- call. args[4 ] = append_args_skip! (Expr (:call , :add_fast , call. args[4 ]), exargs, i, mod)
126
- end
127
- return true , cnmul, false
128
- end
129
- elseif issub
130
- found, cnmul, csub = recursive_muladd_search! (call, exa, mod)
131
- if found
132
- if i == 1
133
- if csub
134
- call. args[4 ] = Expr (:call , :add_fast , call. args[4 ], exargs[3 - i])
135
- else
136
- call. args[4 ] = Expr (:call , :sub_fast , call. args[4 ], exargs[3 - i])
137
- end
138
- else
139
- cnmul = ! cnmul
140
- if csub
141
- call. args[4 ] = Expr (:call , :add_fast , exargs[3 - i], call. args[4 ])
142
- else
143
- call. args[4 ] = Expr (:call , :sub_fast , exargs[3 - i], call. args[4 ])
144
- end
145
- csub = false
146
- end
147
- return true , cnmul, csub
148
- end
99
+ end
100
+ issub && @assert length (exargs) == 2
101
+ for (i,ex) ∈ enumerate (exargs)
102
+ if ex isa Expr && ex. head === :call
103
+ exa = ex. args
104
+ f = first (exa)
105
+ exav = @view (exa[2 : end ])
106
+ if f === :* || f === :mul_fast || f === :vmul || (f == :(Base. FastMath. mul_fast)):: Bool
107
+ a, b = mulexpr (exav)
108
+ call. args[2 ] = a
109
+ call. args[3 ] = b
110
+ if length (exargs) == 2
111
+ push! (call. args, exargs[3 - i])
112
+ else
113
+ push! (call. args, append_args_skip! (Expr (:call , :add_fast ), exargs, i, mod))
114
+ end
115
+ if issub
116
+ csub = i == 1
117
+ cnmul = ! csub
118
+ end
119
+ return true , cnmul, csub
120
+ elseif isadd
121
+ found, cnmul, csub = recursive_muladd_search! (call, exa, mod)
122
+ if found
123
+ if csub
124
+ call. args[4 ] = if length (exargs) == 2
125
+ Expr (:call , :sub_fast , exargs[3 - i], call. args[4 ])
126
+ else
127
+ Expr (:call , :sub_fast , append_args_skip! (Expr (:call , :add_fast ), exargs, i, mod), call. args[4 ])
149
128
end
129
+ else
130
+ call. args[4 ] = append_args_skip! (Expr (:call , :add_fast , call. args[4 ]), exargs, i, mod)
131
+ end
132
+ return true , cnmul, false
150
133
end
134
+ elseif issub
135
+ found, cnmul, csub = recursive_muladd_search! (call, exa, mod)
136
+ if found
137
+ if i == 1
138
+ if csub
139
+ call. args[4 ] = Expr (:call , :add_fast , call. args[4 ], exargs[3 - i])
140
+ else
141
+ call. args[4 ] = Expr (:call , :sub_fast , call. args[4 ], exargs[3 - i])
142
+ end
143
+ else
144
+ cnmul = ! cnmul
145
+ if csub
146
+ call. args[4 ] = Expr (:call , :add_fast , exargs[3 - i], call. args[4 ])
147
+ else
148
+ call. args[4 ] = Expr (:call , :sub_fast , exargs[3 - i], call. args[4 ])
149
+ end
150
+ csub = false
151
+ end
152
+ return true , cnmul, csub
153
+ end
154
+ end
151
155
end
152
- length (call. args) == 4 , cnmul, csub
156
+ end
157
+ length (call. args) == 4 , cnmul, csub
153
158
end
154
159
155
160
function capture_a_muladd (ex:: Expr , mod)
190
195
function capture_muladd (ex:: Expr , mod)
191
196
while true
192
197
ex. head === :ref && return ex
198
+ if Meta. isexpr (ex, :call , 2 )
199
+ if (ex. args[1 ] === :(- ))
200
+ if (ex. args[2 ] isa Number)
201
+ return - ex. args[2 ]
202
+ elseif ex. args[2 ] === :Inf
203
+ return - Inf
204
+ end
205
+ end
206
+ end
193
207
found, ex = capture_a_muladd (ex, mod)
194
208
found || return ex
195
209
end
0 commit comments