@@ -3,7 +3,7 @@ module EvaluateEquationDerivativeModule
33import LoopVectorization: indices, @turbo
44import .. EquationModule: Node
55import .. OperatorEnumModule: OperatorEnum
6- import .. UtilsModule: @return_on_false2 , @maybe_turbo , is_bad_array
6+ import .. UtilsModule: @return_on_false2 , @maybe_turbo , is_bad_array, fill_similar
77import .. EquationUtilsModule: count_constants, index_constants, NodeIndex
88import .. EvaluateEquationModule: deg0_eval
99
@@ -61,7 +61,7 @@ function eval_diff_tree_array(
6161 T = promote_type (T1, T2)
6262 @warn " Warning: eval_diff_tree_array received mixed types: tree=$(T1) and data=$(T2) ."
6363 tree = convert (Node{T}, tree)
64- cX = convert (AbstractMatrix{T}, cX)
64+ cX = T .( cX)
6565 return eval_diff_tree_array (tree, cX, operators, direction; turbo= turbo)
6666end
6767
@@ -102,10 +102,12 @@ end
102102function diff_deg0_eval (
103103 tree:: Node{T} , cX:: AbstractMatrix{T} , direction:: Int
104104):: Tuple{AbstractVector{T},AbstractVector{T},Bool} where {T<: Number }
105- n = size (cX, 2 )
106105 const_part = deg0_eval (tree, cX)[1 ]
107- derivative_part =
108- ((! tree. constant) && tree. feature == direction) ? ones (T, n) : zeros (T, n)
106+ derivative_part = if ((! tree. constant) && tree. feature == direction)
107+ fill_similar (one (T), cX, axes (cX, 2 ))
108+ else
109+ fill_similar (zero (T), cX, axes (cX, 2 ))
110+ end
109111 return (const_part, derivative_part, true )
110112end
111113
@@ -118,7 +120,6 @@ function diff_deg1_eval(
118120 direction:: Int ,
119121 :: Val{turbo} ,
120122):: Tuple{AbstractVector{T},AbstractVector{T},Bool} where {T<: Number ,F,dF,turbo}
121- n = size (cX, 2 )
122123 (cumulator, dcumulator, complete) = _eval_diff_tree_array (
123124 tree. l, cX, operators, direction, Val (turbo)
124125 )
@@ -196,17 +197,11 @@ function eval_grad_tree_array(
196197 turbo:: Bool = false ,
197198):: Tuple{AbstractVector{T},AbstractMatrix{T},Bool} where {T<: Number }
198199 assert_autodiff_enabled (operators)
199- n = size (cX, 2 )
200- if variable
201- n_gradients = size (cX, 1 )
202- else
203- n_gradients = count_constants (tree)
204- end
200+ n_gradients = variable ? size (cX, 1 ) : count_constants (tree)
205201 index_tree = index_constants (tree, 0 )
206202 return eval_grad_tree_array (
207203 tree,
208- n,
209- n_gradients,
204+ Val (n_gradients),
210205 index_tree,
211206 cX,
212207 operators,
@@ -217,16 +212,17 @@ end
217212
218213function eval_grad_tree_array (
219214 tree:: Node{T} ,
220- n:: Int ,
221- n_gradients:: Int ,
215+ :: Val{n_gradients} ,
222216 index_tree:: NodeIndex ,
223217 cX:: AbstractMatrix{T} ,
224218 operators:: OperatorEnum ,
225219 :: Val{variable} ,
226220 :: Val{turbo} ,
227- ):: Tuple{AbstractVector{T},AbstractMatrix{T},Bool} where {T<: Number ,variable,turbo}
221+ ):: Tuple {
222+ AbstractVector{T},AbstractMatrix{T},Bool
223+ } where {T<: Number ,variable,turbo,n_gradients}
228224 evaluation, gradient, complete = _eval_grad_tree_array (
229- tree, n, n_gradients, index_tree, cX, operators, Val (variable), Val (turbo)
225+ tree, Val ( n_gradients) , index_tree, cX, operators, Val (variable), Val (turbo)
230226 )
231227 @return_on_false2 complete evaluation gradient
232228 return evaluation, gradient, ! (is_bad_array (evaluation) || is_bad_array (gradient))
@@ -251,21 +247,21 @@ end
251247
252248function _eval_grad_tree_array (
253249 tree:: Node{T} ,
254- n:: Int ,
255- n_gradients:: Int ,
250+ :: Val{n_gradients} ,
256251 index_tree:: NodeIndex ,
257252 cX:: AbstractMatrix{T} ,
258253 operators:: OperatorEnum ,
259254 :: Val{variable} ,
260255 :: Val{turbo} ,
261- ):: Tuple{AbstractVector{T},AbstractMatrix{T},Bool} where {T<: Number ,variable,turbo}
256+ ):: Tuple {
257+ AbstractVector{T},AbstractMatrix{T},Bool
258+ } where {T<: Number ,variable,turbo,n_gradients}
262259 if tree. degree == 0
263- grad_deg0_eval (tree, n, n_gradients, index_tree, cX, Val (variable))
260+ grad_deg0_eval (tree, Val ( n_gradients) , index_tree, cX, Val (variable))
264261 elseif tree. degree == 1
265262 grad_deg1_eval (
266263 tree,
267- n,
268- n_gradients,
264+ Val (n_gradients),
269265 index_tree,
270266 cX,
271267 operators. unaops[tree. op],
@@ -277,8 +273,7 @@ function _eval_grad_tree_array(
277273 else
278274 grad_deg2_eval (
279275 tree,
280- n,
281- n_gradients,
276+ Val (n_gradients),
282277 index_tree,
283278 cX,
284279 operators. binops[tree. op],
@@ -292,38 +287,40 @@ end
292287
293288function grad_deg0_eval (
294289 tree:: Node{T} ,
295- n:: Int ,
296- n_gradients:: Int ,
290+ :: Val{n_gradients} ,
297291 index_tree:: NodeIndex ,
298292 cX:: AbstractMatrix{T} ,
299293 :: Val{variable} ,
300- ):: Tuple{AbstractVector{T},AbstractMatrix{T},Bool} where {T<: Number ,variable}
294+ ):: Tuple{AbstractVector{T},AbstractMatrix{T},Bool} where {T<: Number ,variable,n_gradients }
301295 const_part = deg0_eval (tree, cX)[1 ]
302296
297+ zero_mat = hcat ((fill_similar (zero (T), cX, axes (cX, 2 )) for _ in 1 : n_gradients). .. )'
298+
303299 if variable == tree. constant
304- return (const_part, zeros (T, n_gradients, n) , true )
300+ return (const_part, zero_mat , true )
305301 end
306302
307303 index = variable ? tree. feature : index_tree. constant_index
308- derivative_part = zeros (T, n_gradients, n)
304+ derivative_part = zero_mat
309305 derivative_part[index, :] .= one (T)
310306 return (const_part, derivative_part, true )
311307end
312308
313309function grad_deg1_eval (
314310 tree:: Node{T} ,
315- n:: Int ,
316- n_gradients:: Int ,
311+ :: Val{n_gradients} ,
317312 index_tree:: NodeIndex ,
318313 cX:: AbstractMatrix{T} ,
319314 op:: F ,
320315 diff_op:: dF ,
321316 operators:: OperatorEnum ,
322317 :: Val{variable} ,
323318 :: Val{turbo} ,
324- ):: Tuple{AbstractVector{T},AbstractMatrix{T},Bool} where {T<: Number ,F,dF,variable,turbo}
319+ ):: Tuple {
320+ AbstractVector{T},AbstractMatrix{T},Bool
321+ } where {T<: Number ,F,dF,variable,turbo,n_gradients}
325322 (cumulator, dcumulator, complete) = eval_grad_tree_array (
326- tree. l, n, n_gradients, index_tree. l, cX, operators, Val (variable), Val (turbo)
323+ tree. l, Val ( n_gradients) , index_tree. l, cX, operators, Val (variable), Val (turbo)
327324 )
328325 @return_on_false2 complete cumulator dcumulator
329326
@@ -341,22 +338,23 @@ end
341338
342339function grad_deg2_eval (
343340 tree:: Node{T} ,
344- n:: Int ,
345- n_gradients:: Int ,
341+ :: Val{n_gradients} ,
346342 index_tree:: NodeIndex ,
347343 cX:: AbstractMatrix{T} ,
348344 op:: F ,
349345 diff_op:: dF ,
350346 operators:: OperatorEnum ,
351347 :: Val{variable} ,
352348 :: Val{turbo} ,
353- ):: Tuple{AbstractVector{T},AbstractMatrix{T},Bool} where {T<: Number ,F,dF,variable,turbo}
349+ ):: Tuple {
350+ AbstractVector{T},AbstractMatrix{T},Bool
351+ } where {T<: Number ,F,dF,variable,turbo,n_gradients}
354352 (cumulator1, dcumulator1, complete) = eval_grad_tree_array (
355- tree. l, n, n_gradients, index_tree. l, cX, operators, Val (variable), Val (turbo)
353+ tree. l, Val ( n_gradients) , index_tree. l, cX, operators, Val (variable), Val (turbo)
356354 )
357355 @return_on_false2 complete cumulator1 dcumulator1
358356 (cumulator2, dcumulator2, complete2) = eval_grad_tree_array (
359- tree. r, n, n_gradients, index_tree. r, cX, operators, Val (variable), Val (turbo)
357+ tree. r, Val ( n_gradients) , index_tree. r, cX, operators, Val (variable), Val (turbo)
360358 )
361359 @return_on_false2 complete2 cumulator1 dcumulator1
362360
0 commit comments