Skip to content

Commit c24e54f

Browse files
Merge pull request #112 from SciML/adjoints
add some more adjoints and test adjoints
2 parents bf52cf3 + 9a94f7a commit c24e54f

File tree

5 files changed

+53
-4
lines changed

5 files changed

+53
-4
lines changed

Project.toml

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "RecursiveArrayTools"
22
uuid = "731186ca-8d62-57ce-b412-fbd966d074cd"
33
authors = ["Chris Rackauckas <[email protected]>"]
4-
version = "2.5.0"
4+
version = "2.6.0"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
@@ -21,11 +21,13 @@ ZygoteRules = "0.2"
2121
julia = "1.3"
2222

2323
[extras]
24+
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
2425
NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
2526
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
27+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
2628
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2729
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
28-
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
30+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2931

3032
[targets]
31-
test = ["NLsolve", "OrdinaryDiffEq", "Test", "Unitful", "Random"]
33+
test = ["ForwardDiff", "NLsolve", "OrdinaryDiffEq", "Test", "Unitful", "Random", "Zygote"]

src/zygote.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,3 +23,11 @@ ZygoteRules.@adjoint function ArrayPartition(x...)
2323
end
2424
ArrayPartition(x...),ArrayPartition_adjoint
2525
end
26+
27+
ZygoteRules.@adjoint function VectorOfArray(u)
28+
VectorOfArray(u),y -> ([y[ntuple(x->Colon(),ndims(y)-1)...,i] for i in 1:size(y)[end]],)
29+
end
30+
31+
ZygoteRules.@adjoint function DiffEqArray(u,t)
32+
DiffEqArray(u,t),y -> ([y[ntuple(x->Colon(),ndims(y)-1)...,i] for i in 1:size(y)[end]],nothing)
33+
end

test/adjoints.jl

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
using RecursiveArrayTools, Zygote, ForwardDiff, Test
2+
3+
function loss(x)
4+
sum(abs2,Array(VectorOfArray([x .* i for i in 1:5])))
5+
end
6+
7+
function loss2(x)
8+
sum(abs2,Array(DiffEqArray([x .* i for i in 1:5],1:5)))
9+
end
10+
11+
function loss3(x)
12+
y = VectorOfArray([x .* i for i in 1:5])
13+
tmp = 0.0
14+
for i in 1:5, j in 1:5
15+
tmp += y[i,j]
16+
end
17+
tmp
18+
end
19+
20+
function loss4(x)
21+
y = DiffEqArray([x .* i for i in 1:5],1:5)
22+
tmp = 0.0
23+
for i in 1:5, j in 1:5
24+
tmp += y[i,j]
25+
end
26+
tmp
27+
end
28+
29+
function loss5(x)
30+
sum(abs2,Array(ArrayPartition([x .* i for i in 1:5]...)))
31+
end
32+
33+
x = float.(6:10)
34+
loss(x)
35+
@test Zygote.gradient(loss,x)[1] == ForwardDiff.gradient(loss,x)
36+
@test Zygote.gradient(loss2,x)[1] == ForwardDiff.gradient(loss2,x)
37+
@test Zygote.gradient(loss3,x)[1] == ForwardDiff.gradient(loss3,x)
38+
@test Zygote.gradient(loss4,x)[1] == ForwardDiff.gradient(loss4,x)
39+
@test Zygote.gradient(loss5,x)[1] == ForwardDiff.gradient(loss5,x)

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,5 @@ using Test
99
@time @testset "StaticArrays Tests" begin include("copy_static_array_test.jl") end
1010
@time @testset "Linear Algebra Tests" begin include("linalg.jl") end
1111
@time @testset "Upstream Tests" begin include("upstream.jl") end
12+
@time @testset "Adjoint Tests" begin include("adjoints.jl") end
1213
end

test/utils_test.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ AofuSA = [@SVector [2.0u"kg",3.0u"kg"] for i in 1:5]
2727
@test recursive_unitless_eltype(AofuSA) == SVector{2,Float64}
2828

2929
A = [ArrayPartition(ones(1),ones(1)),]
30-
@test repr("text/plain", A) == "1-element Array{ArrayPartition{Float64,Tuple{Array{Float64,1},Array{Float64,1}}},1}:\n [1.0][1.0]"
3130

3231
function test_recursive_bottom_eltype()
3332
function test_value(val::Any, expected_type::Type)

0 commit comments

Comments
 (0)