@@ -3,26 +3,21 @@ module EvaluateEquationModule
33import LoopVectorization: @turbo , indices
44import .. EquationModule: Node, string_tree
55import .. OperatorEnumModule: OperatorEnum, GenericOperatorEnum
6- import .. UtilsModule: @return_on_false , @maybe_turbo , is_bad_array
6+ import .. UtilsModule: @return_on_false , @maybe_turbo , is_bad_array, fill_similar
77import .. EquationUtilsModule: is_constant
88
9- macro return_on_check (val, T, n)
10- # This will generate the following code:
11- # if !isfinite(val)
12- # return (Array{T, 1}(undef, n), false)
13- # end
14-
9+ macro return_on_check (val, X)
1510 :(
1611 if ! isfinite ($ (esc (val)))
17- return (Array { $(esc(T )),1} (undef, $ (esc (n) )), false )
12+ return (similar ( $ (esc (X )), axes ( $ (esc (X)), 2 )), false )
1813 end
1914 )
2015end
2116
22- macro return_on_nonfinite_array (array, T, n )
17+ macro return_on_nonfinite_array (array)
2318 :(
2419 if is_bad_array ($ (esc (array)))
25- return (Array { $(esc(T)),1} (undef, $ ( esc (n) )), false )
20+ return ($ (esc (array )), false )
2621 end
2722 )
2823end
@@ -64,15 +59,14 @@ which speed up evaluation significantly.
6459function eval_tree_array (
6560 tree:: Node{T} , cX:: AbstractMatrix{T} , operators:: OperatorEnum ; turbo:: Bool = false
6661):: Tuple{AbstractVector{T},Bool} where {T<: Number }
67- n = size (cX, 2 )
6862 if turbo
6963 @assert T in (Float32, Float64)
7064 end
7165 result, finished = _eval_tree_array (
7266 tree, cX, operators, (turbo ? Val (true ) : Val (false ))
7367 )
7468 @return_on_false finished result
75- @return_on_nonfinite_array result T n
69+ @return_on_nonfinite_array result
7670 return result, finished
7771end
7872function eval_tree_array (
@@ -81,23 +75,22 @@ function eval_tree_array(
8175 T = promote_type (T1, T2)
8276 @warn " Warning: eval_tree_array received mixed types: tree=$(T1) and data=$(T2) ."
8377 tree = convert (Node{T}, tree)
84- cX = convert (AbstractMatrix{T}, cX)
78+ cX = T .( cX)
8579 return eval_tree_array (tree, cX, operators; turbo= turbo)
8680end
8781
8882function _eval_tree_array (
8983 tree:: Node{T} , cX:: AbstractMatrix{T} , operators:: OperatorEnum , :: Val{turbo}
9084):: Tuple{AbstractVector{T},Bool} where {T<: Number ,turbo}
91- n = size (cX, 2 )
9285 # First, we see if there are only constants in the tree - meaning
9386 # we can just return the constant result.
9487 if tree. degree == 0
9588 return deg0_eval (tree, cX)
9689 elseif is_constant (tree)
9790 # Speed hack for constant trees.
9891 result, flag = _eval_constant_tree (tree, operators)
99- ! flag && return Array {T,1} (undef, size (cX, 2 )), false
100- return fill (result, size (cX, 2 )), true
92+ ! flag && return similar (cX, axes (cX, 2 )), false
93+ return fill_similar (result, cX, axes (cX, 2 )), true
10194 elseif tree. degree == 1
10295 op = operators. unaops[tree. op]
10396 if tree. l. degree == 2 && tree. l. l. degree == 0 && tree. l. r. degree == 0
@@ -113,7 +106,7 @@ function _eval_tree_array(
113106 # op(x), for any x.
114107 (cumulator, complete) = _eval_tree_array (tree. l, cX, operators, Val (turbo))
115108 @return_on_false complete cumulator
116- @return_on_nonfinite_array cumulator T n
109+ @return_on_nonfinite_array cumulator
117110 return deg1_eval (cumulator, op, Val (turbo))
118111
119112 elseif tree. degree == 2
@@ -125,22 +118,22 @@ function _eval_tree_array(
125118 elseif tree. r. degree == 0
126119 (cumulator_l, complete) = _eval_tree_array (tree. l, cX, operators, Val (turbo))
127120 @return_on_false complete cumulator_l
128- @return_on_nonfinite_array cumulator_l T n
121+ @return_on_nonfinite_array cumulator_l
129122 # op(x, y), where y is a constant or variable but x is not.
130123 return deg2_r0_eval (tree, cumulator_l, cX, op, Val (turbo))
131124 elseif tree. l. degree == 0
132125 (cumulator_r, complete) = _eval_tree_array (tree. r, cX, operators, Val (turbo))
133126 @return_on_false complete cumulator_r
134- @return_on_nonfinite_array cumulator_r T n
127+ @return_on_nonfinite_array cumulator_r
135128 # op(x, y), where x is a constant or variable but y is not.
136129 return deg2_l0_eval (tree, cumulator_r, cX, op, Val (turbo))
137130 end
138131 (cumulator_l, complete) = _eval_tree_array (tree. l, cX, operators, Val (turbo))
139132 @return_on_false complete cumulator_l
140- @return_on_nonfinite_array cumulator_l T n
133+ @return_on_nonfinite_array cumulator_l
141134 (cumulator_r, complete) = _eval_tree_array (tree. r, cX, operators, Val (turbo))
142135 @return_on_false complete cumulator_r
143- @return_on_nonfinite_array cumulator_r T n
136+ @return_on_nonfinite_array cumulator_r
144137 # op(x, y), for any x or y
145138 return deg2_eval (cumulator_l, cumulator_r, op, Val (turbo))
146139 end
@@ -170,8 +163,7 @@ function deg0_eval(
170163 tree:: Node{T} , cX:: AbstractMatrix{T}
171164):: Tuple{AbstractVector{T},Bool} where {T<: Number }
172165 if tree. constant
173- n = size (cX, 2 )
174- return (fill (tree. val:: T , n), true )
166+ return (fill_similar (tree. val:: T , cX, axes (cX, 2 )), true )
175167 else
176168 return (cX[tree. feature, :], true )
177169 end
@@ -180,22 +172,21 @@ end
180172function deg1_l2_ll0_lr0_eval (
181173 tree:: Node{T} , cX:: AbstractMatrix{T} , op:: F , op_l:: F2 , :: Val{turbo}
182174):: Tuple{AbstractVector{T},Bool} where {T<: Number ,F,F2,turbo}
183- n = size (cX, 2 )
184175 if tree. l. l. constant && tree. l. r. constant
185176 val_ll = tree. l. l. val:: T
186177 val_lr = tree. l. r. val:: T
187- @return_on_check val_ll T n
188- @return_on_check val_lr T n
178+ @return_on_check val_ll cX
179+ @return_on_check val_lr cX
189180 x_l = op_l (val_ll, val_lr):: T
190- @return_on_check x_l T n
181+ @return_on_check x_l cX
191182 x = op (x_l):: T
192- @return_on_check x T n
193- return (fill (x, n ), true )
183+ @return_on_check x cX
184+ return (fill_similar (x, cX, axes (cX, 2 ) ), true )
194185 elseif tree. l. l. constant
195186 val_ll = tree. l. l. val:: T
196- @return_on_check val_ll T n
187+ @return_on_check val_ll cX
197188 feature_lr = tree. l. r. feature
198- cumulator = Array {T,1} (undef, n )
189+ cumulator = similar (cX, axes (cX, 2 ) )
199190 @maybe_turbo turbo for j in indices ((cX, cumulator), (2 , 1 ))
200191 x_l = op_l (val_ll, cX[feature_lr, j]):: T
201192 x = isfinite (x_l) ? op (x_l):: T : T (Inf )
@@ -205,8 +196,8 @@ function deg1_l2_ll0_lr0_eval(
205196 elseif tree. l. r. constant
206197 feature_ll = tree. l. l. feature
207198 val_lr = tree. l. r. val:: T
208- @return_on_check val_lr T n
209- cumulator = Array {T,1} (undef, n )
199+ @return_on_check val_lr cX
200+ cumulator = similar (cX, axes (cX, 2 ) )
210201 @maybe_turbo turbo for j in indices ((cX, cumulator), (2 , 1 ))
211202 x_l = op_l (cX[feature_ll, j], val_lr):: T
212203 x = isfinite (x_l) ? op (x_l):: T : T (Inf )
@@ -216,7 +207,7 @@ function deg1_l2_ll0_lr0_eval(
216207 else
217208 feature_ll = tree. l. l. feature
218209 feature_lr = tree. l. r. feature
219- cumulator = Array {T,1} (undef, n )
210+ cumulator = similar (cX, axes (cX, 2 ) )
220211 @maybe_turbo turbo for j in indices ((cX, cumulator), (2 , 1 ))
221212 x_l = op_l (cX[feature_ll, j], cX[feature_lr, j]):: T
222213 x = isfinite (x_l) ? op (x_l):: T : T (Inf )
@@ -230,18 +221,17 @@ end
230221function deg1_l1_ll0_eval (
231222 tree:: Node{T} , cX:: AbstractMatrix{T} , op:: F , op_l:: F2 , :: Val{turbo}
232223):: Tuple{AbstractVector{T},Bool} where {T<: Number ,F,F2,turbo}
233- n = size (cX, 2 )
234224 if tree. l. l. constant
235225 val_ll = tree. l. l. val:: T
236- @return_on_check val_ll T n
226+ @return_on_check val_ll cX
237227 x_l = op_l (val_ll):: T
238- @return_on_check x_l T n
228+ @return_on_check x_l cX
239229 x = op (x_l):: T
240- @return_on_check x T n
241- return (fill (x, n ), true )
230+ @return_on_check x cX
231+ return (fill_similar (x, cX, axes (cX, 2 ) ), true )
242232 else
243233 feature_ll = tree. l. l. feature
244- cumulator = Array {T,1} (undef, n )
234+ cumulator = similar (cX, axes (cX, 2 ) )
245235 @maybe_turbo turbo for j in indices ((cX, cumulator), (2 , 1 ))
246236 x_l = op_l (cX[feature_ll, j]):: T
247237 x = isfinite (x_l) ? op (x_l):: T : T (Inf )
@@ -255,35 +245,34 @@ end
255245function deg2_l0_r0_eval (
256246 tree:: Node{T} , cX:: AbstractMatrix{T} , op:: F , :: Val{turbo}
257247):: Tuple{AbstractVector{T},Bool} where {T<: Number ,F,turbo}
258- n = size (cX, 2 )
259248 if tree. l. constant && tree. r. constant
260249 val_l = tree. l. val:: T
261- @return_on_check val_l T n
250+ @return_on_check val_l cX
262251 val_r = tree. r. val:: T
263- @return_on_check val_r T n
252+ @return_on_check val_r cX
264253 x = op (val_l, val_r):: T
265- @return_on_check x T n
266- return (fill (x, n ), true )
254+ @return_on_check x cX
255+ return (fill_similar (x, cX, axes (cX, 2 ) ), true )
267256 elseif tree. l. constant
268- cumulator = Array {T,1} (undef, n )
257+ cumulator = similar (cX, axes (cX, 2 ) )
269258 val_l = tree. l. val:: T
270- @return_on_check val_l T n
259+ @return_on_check val_l cX
271260 feature_r = tree. r. feature
272261 @maybe_turbo turbo for j in indices ((cX, cumulator), (2 , 1 ))
273262 x = op (val_l, cX[feature_r, j]):: T
274263 cumulator[j] = x
275264 end
276265 elseif tree. r. constant
277- cumulator = Array {T,1} (undef, n )
266+ cumulator = similar (cX, axes (cX, 2 ) )
278267 feature_l = tree. l. feature
279268 val_r = tree. r. val:: T
280- @return_on_check val_r T n
269+ @return_on_check val_r cX
281270 @maybe_turbo turbo for j in indices ((cX, cumulator), (2 , 1 ))
282271 x = op (cX[feature_l, j], val_r):: T
283272 cumulator[j] = x
284273 end
285274 else
286- cumulator = Array {T,1} (undef, n )
275+ cumulator = similar (cX, axes (cX, 2 ) )
287276 feature_l = tree. l. feature
288277 feature_r = tree. r. feature
289278 @maybe_turbo turbo for j in indices ((cX, cumulator), (2 , 1 ))
298287function deg2_l0_eval (
299288 tree:: Node{T} , cumulator:: AbstractVector{T} , cX:: AbstractArray{T} , op:: F , :: Val{turbo}
300289):: Tuple{AbstractVector{T},Bool} where {T<: Number ,F,turbo}
301- n = size (cX, 2 )
302290 if tree. l. constant
303291 val = tree. l. val:: T
304- @return_on_check val T n
292+ @return_on_check val cX
305293 @maybe_turbo turbo for j in indices (cumulator)
306294 x = op (val, cumulator[j]):: T
307295 cumulator[j] = x
320308function deg2_r0_eval (
321309 tree:: Node{T} , cumulator:: AbstractVector{T} , cX:: AbstractArray{T} , op:: F , :: Val{turbo}
322310):: Tuple{AbstractVector{T},Bool} where {T<: Number ,F,turbo}
323- n = size (cX, 2 )
324311 if tree. r. constant
325312 val = tree. r. val:: T
326- @return_on_check val T n
313+ @return_on_check val cX
327314 @maybe_turbo turbo for j in indices (cumulator)
328315 x = op (cumulator[j], val):: T
329316 cumulator[j] = x
@@ -389,10 +376,9 @@ Evaluate an expression tree in a way that can be auto-differentiated.
389376function differentiable_eval_tree_array (
390377 tree:: Node{T1} , cX:: AbstractMatrix{T} , operators:: OperatorEnum
391378):: Tuple{AbstractVector{T},Bool} where {T<: Number ,T1}
392- n = size (cX, 2 )
393379 if tree. degree == 0
394380 if tree. constant
395- return (ones (T, n) .* convert (T, tree. val) , true )
381+ return (fill_similar ( one (T), cX, axes (cX, 2 )) .* tree. val, true )
396382 else
397383 return (cX[tree. feature, :], true )
398384 end
0 commit comments