Skip to content

Commit 706728e

Browse files
committed
Ensure that derivatives preserve containers
1 parent f4114df commit 706728e

File tree

2 files changed

+66
-50
lines changed

2 files changed

+66
-50
lines changed

src/EvaluateEquationDerivative.jl

Lines changed: 38 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ module EvaluateEquationDerivativeModule
33
import LoopVectorization: indices, @turbo
44
import ..EquationModule: Node
55
import ..OperatorEnumModule: OperatorEnum
6-
import ..UtilsModule: @return_on_false2, @maybe_turbo, is_bad_array
6+
import ..UtilsModule: @return_on_false2, @maybe_turbo, is_bad_array, fill_similar
77
import ..EquationUtilsModule: count_constants, index_constants, NodeIndex
88
import ..EvaluateEquationModule: deg0_eval
99

@@ -61,7 +61,7 @@ function eval_diff_tree_array(
6161
T = promote_type(T1, T2)
6262
@warn "Warning: eval_diff_tree_array received mixed types: tree=$(T1) and data=$(T2)."
6363
tree = convert(Node{T}, tree)
64-
cX = convert(AbstractMatrix{T}, cX)
64+
cX = T.(cX)
6565
return eval_diff_tree_array(tree, cX, operators, direction; turbo=turbo)
6666
end
6767

@@ -102,10 +102,12 @@ end
102102
function diff_deg0_eval(
103103
tree::Node{T}, cX::AbstractMatrix{T}, direction::Int
104104
)::Tuple{AbstractVector{T},AbstractVector{T},Bool} where {T<:Number}
105-
n = size(cX, 2)
106105
const_part = deg0_eval(tree, cX)[1]
107-
derivative_part =
108-
((!tree.constant) && tree.feature == direction) ? ones(T, n) : zeros(T, n)
106+
derivative_part = if ((!tree.constant) && tree.feature == direction)
107+
fill_similar(one(T), cX, axes(cX, 2))
108+
else
109+
fill_similar(zero(T), cX, axes(cX, 2))
110+
end
109111
return (const_part, derivative_part, true)
110112
end
111113

@@ -118,7 +120,6 @@ function diff_deg1_eval(
118120
direction::Int,
119121
::Val{turbo},
120122
)::Tuple{AbstractVector{T},AbstractVector{T},Bool} where {T<:Number,F,dF,turbo}
121-
n = size(cX, 2)
122123
(cumulator, dcumulator, complete) = _eval_diff_tree_array(
123124
tree.l, cX, operators, direction, Val(turbo)
124125
)
@@ -196,17 +197,11 @@ function eval_grad_tree_array(
196197
turbo::Bool=false,
197198
)::Tuple{AbstractVector{T},AbstractMatrix{T},Bool} where {T<:Number}
198199
assert_autodiff_enabled(operators)
199-
n = size(cX, 2)
200-
if variable
201-
n_gradients = size(cX, 1)
202-
else
203-
n_gradients = count_constants(tree)
204-
end
200+
n_gradients = variable ? size(cX, 1) : count_constants(tree)
205201
index_tree = index_constants(tree, 0)
206202
return eval_grad_tree_array(
207203
tree,
208-
n,
209-
n_gradients,
204+
Val(n_gradients),
210205
index_tree,
211206
cX,
212207
operators,
@@ -217,16 +212,17 @@ end
217212

218213
function eval_grad_tree_array(
219214
tree::Node{T},
220-
n::Int,
221-
n_gradients::Int,
215+
::Val{n_gradients},
222216
index_tree::NodeIndex,
223217
cX::AbstractMatrix{T},
224218
operators::OperatorEnum,
225219
::Val{variable},
226220
::Val{turbo},
227-
)::Tuple{AbstractVector{T},AbstractMatrix{T},Bool} where {T<:Number,variable,turbo}
221+
)::Tuple{
222+
AbstractVector{T},AbstractMatrix{T},Bool
223+
} where {T<:Number,variable,turbo,n_gradients}
228224
evaluation, gradient, complete = _eval_grad_tree_array(
229-
tree, n, n_gradients, index_tree, cX, operators, Val(variable), Val(turbo)
225+
tree, Val(n_gradients), index_tree, cX, operators, Val(variable), Val(turbo)
230226
)
231227
@return_on_false2 complete evaluation gradient
232228
return evaluation, gradient, !(is_bad_array(evaluation) || is_bad_array(gradient))
@@ -251,21 +247,21 @@ end
251247

252248
function _eval_grad_tree_array(
253249
tree::Node{T},
254-
n::Int,
255-
n_gradients::Int,
250+
::Val{n_gradients},
256251
index_tree::NodeIndex,
257252
cX::AbstractMatrix{T},
258253
operators::OperatorEnum,
259254
::Val{variable},
260255
::Val{turbo},
261-
)::Tuple{AbstractVector{T},AbstractMatrix{T},Bool} where {T<:Number,variable,turbo}
256+
)::Tuple{
257+
AbstractVector{T},AbstractMatrix{T},Bool
258+
} where {T<:Number,variable,turbo,n_gradients}
262259
if tree.degree == 0
263-
grad_deg0_eval(tree, n, n_gradients, index_tree, cX, Val(variable))
260+
grad_deg0_eval(tree, Val(n_gradients), index_tree, cX, Val(variable))
264261
elseif tree.degree == 1
265262
grad_deg1_eval(
266263
tree,
267-
n,
268-
n_gradients,
264+
Val(n_gradients),
269265
index_tree,
270266
cX,
271267
operators.unaops[tree.op],
@@ -277,8 +273,7 @@ function _eval_grad_tree_array(
277273
else
278274
grad_deg2_eval(
279275
tree,
280-
n,
281-
n_gradients,
276+
Val(n_gradients),
282277
index_tree,
283278
cX,
284279
operators.binops[tree.op],
@@ -292,38 +287,40 @@ end
292287

293288
function grad_deg0_eval(
294289
tree::Node{T},
295-
n::Int,
296-
n_gradients::Int,
290+
::Val{n_gradients},
297291
index_tree::NodeIndex,
298292
cX::AbstractMatrix{T},
299293
::Val{variable},
300-
)::Tuple{AbstractVector{T},AbstractMatrix{T},Bool} where {T<:Number,variable}
294+
)::Tuple{AbstractVector{T},AbstractMatrix{T},Bool} where {T<:Number,variable,n_gradients}
301295
const_part = deg0_eval(tree, cX)[1]
302296

297+
zero_mat = hcat((fill_similar(zero(T), cX, axes(cX, 2)) for _ in 1:n_gradients)...)'
298+
303299
if variable == tree.constant
304-
return (const_part, zeros(T, n_gradients, n), true)
300+
return (const_part, zero_mat, true)
305301
end
306302

307303
index = variable ? tree.feature : index_tree.constant_index
308-
derivative_part = zeros(T, n_gradients, n)
304+
derivative_part = zero_mat
309305
derivative_part[index, :] .= one(T)
310306
return (const_part, derivative_part, true)
311307
end
312308

313309
function grad_deg1_eval(
314310
tree::Node{T},
315-
n::Int,
316-
n_gradients::Int,
311+
::Val{n_gradients},
317312
index_tree::NodeIndex,
318313
cX::AbstractMatrix{T},
319314
op::F,
320315
diff_op::dF,
321316
operators::OperatorEnum,
322317
::Val{variable},
323318
::Val{turbo},
324-
)::Tuple{AbstractVector{T},AbstractMatrix{T},Bool} where {T<:Number,F,dF,variable,turbo}
319+
)::Tuple{
320+
AbstractVector{T},AbstractMatrix{T},Bool
321+
} where {T<:Number,F,dF,variable,turbo,n_gradients}
325322
(cumulator, dcumulator, complete) = eval_grad_tree_array(
326-
tree.l, n, n_gradients, index_tree.l, cX, operators, Val(variable), Val(turbo)
323+
tree.l, Val(n_gradients), index_tree.l, cX, operators, Val(variable), Val(turbo)
327324
)
328325
@return_on_false2 complete cumulator dcumulator
329326

@@ -341,22 +338,23 @@ end
341338

342339
function grad_deg2_eval(
343340
tree::Node{T},
344-
n::Int,
345-
n_gradients::Int,
341+
::Val{n_gradients},
346342
index_tree::NodeIndex,
347343
cX::AbstractMatrix{T},
348344
op::F,
349345
diff_op::dF,
350346
operators::OperatorEnum,
351347
::Val{variable},
352348
::Val{turbo},
353-
)::Tuple{AbstractVector{T},AbstractMatrix{T},Bool} where {T<:Number,F,dF,variable,turbo}
349+
)::Tuple{
350+
AbstractVector{T},AbstractMatrix{T},Bool
351+
} where {T<:Number,F,dF,variable,turbo,n_gradients}
354352
(cumulator1, dcumulator1, complete) = eval_grad_tree_array(
355-
tree.l, n, n_gradients, index_tree.l, cX, operators, Val(variable), Val(turbo)
353+
tree.l, Val(n_gradients), index_tree.l, cX, operators, Val(variable), Val(turbo)
356354
)
357355
@return_on_false2 complete cumulator1 dcumulator1
358356
(cumulator2, dcumulator2, complete2) = eval_grad_tree_array(
359-
tree.r, n, n_gradients, index_tree.r, cX, operators, Val(variable), Val(turbo)
357+
tree.r, Val(n_gradients), index_tree.r, cX, operators, Val(variable), Val(turbo)
360358
)
361359
@return_on_false2 complete2 cumulator1 dcumulator1
362360

test/test_container_preserved.jl

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,33 @@ using StaticArrays
33
using Test
44

55
@testset "StaticArrays type preserved" begin
6-
X = MMatrix{3,10}(randn(3, 10))
7-
operators = OperatorEnum(; binary_operators=[+, -, *, /], unary_operators=[cos, sin])
8-
x1, x2, x3 = (i -> Node(Float64; feature=i)).(1:3)
9-
tree = cos(x1 * 5.2 - 0.9) * x3 + x2 * x2 - 2.2
10-
y = tree(X, operators)
11-
@test typeof(y) == MVector{10,Float64}
6+
for T in (Float32, Float64)
7+
X = MMatrix{3,10}(randn(T, 3, 10))
8+
operators = OperatorEnum(;
9+
binary_operators=[+, -, *, /], unary_operators=[cos, sin], enable_autodiff=true
10+
)
11+
x1, x2, x3 = (i -> Node(; feature=i)).(1:3)
12+
tree = cos(x1 * 5.2 - 0.9) * x3 + x2 * x2 - 2.2 * x1 + 1.0
13+
tree = convert(Node{T}, tree)
1214

13-
X .= NaN
14-
tree = cos(x1 * 5.2 - 0.9) * x3 + x2 * x2 - 2.2
15-
y = tree(X, operators)
16-
@test typeof(y) == MVector{10,Float64}
15+
y = tree(X, operators)
16+
@test typeof(y) == MVector{10,T}
17+
18+
dy = tree'(X, operators)
19+
@test typeof(dy) == MMatrix{3,10,T,30}
20+
21+
dy = tree'(X, operators; variable=false)
22+
@test typeof(dy) == MMatrix{4,10,T,40}
23+
24+
# Even with NaNs:
25+
X .= T(NaN)
26+
y = tree(X, operators)
27+
@test typeof(y) == MVector{10,T}
28+
29+
dy = tree'(X, operators)
30+
@test typeof(dy) == MMatrix{3,10,T,30}
31+
32+
dy = tree'(X, operators; variable=false)
33+
@test typeof(dy) == MMatrix{4,10,T,40}
34+
end
1735
end

0 commit comments

Comments (0)