Skip to content

Commit abeaf3d

Browse files
Merge pull request #221 from frankschae/ODEProblem_u0
Differentiating ArrayPartitions in DEProblems
2 parents fd72df9 + ddff890 commit abeaf3d

File tree

2 files changed

+9
-1
lines changed

2 files changed

+9
-1
lines changed

src/zygote.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ function ChainRulesCore.rrule(::Type{<:ArrayPartition}, x::S, ::Type{Val{copy_x}
1919
function ArrayPartition_adjoint(_y)
2020
y = Array(_y)
2121
starts = vcat(0,cumsum(reduce(vcat,length.(x))))
22-
NoTangent(), ArrayPartition(ntuple(i -> reshape(y[starts[i]+1:starts[i+1]], size(x[i]))), length(x)), NoTangent()
22+
NoTangent(), ntuple(i -> reshape(y[starts[i]+1:starts[i+1]], size(x[i])), length(x)), NoTangent()
2323
end
2424

2525
ArrayPartition(x, Val{copy_x}), ArrayPartition_adjoint

test/adjoints.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using RecursiveArrayTools, Zygote, ForwardDiff, Test
2+
using OrdinaryDiffEq
23

34
function loss(x)
45
sum(abs2,Array(VectorOfArray([x .* i for i in 1:5])))
@@ -30,10 +31,17 @@ function loss5(x)
3031
sum(abs2,Array(ArrayPartition([x .* i for i in 1:5]...)))
3132
end
3233

34+
function loss6(x)
35+
_x = ArrayPartition([x .* i for i in 1:5]...)
36+
_prob = ODEProblem((u,p,t)->u, _x, (0,1))
37+
sum(abs2, Array(_prob.u0))
38+
end
39+
3340
x = float.(6:10)
3441
loss(x)
3542
@test Zygote.gradient(loss,x)[1] == ForwardDiff.gradient(loss,x)
3643
@test Zygote.gradient(loss2,x)[1] == ForwardDiff.gradient(loss2,x)
3744
@test Zygote.gradient(loss3,x)[1] == ForwardDiff.gradient(loss3,x)
3845
@test Zygote.gradient(loss4,x)[1] == ForwardDiff.gradient(loss4,x)
3946
@test Zygote.gradient(loss5,x)[1] == ForwardDiff.gradient(loss5,x)
47+
@test Zygote.gradient(loss6,x)[1] == ForwardDiff.gradient(loss6,x)

0 commit comments

Comments
 (0)