Skip to content

Commit 0e70295

Browse files
Merge pull request #59 from mcabbott/vecbug
Missing `vec` in gradient of `destructure`
2 parents 1e34fa2 + 140499e commit 0e70295

File tree

3 files changed

+18
-2
lines changed

3 files changed

+18
-2
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Optimisers"
22
uuid = "3bd65402-5787-11e9-1adc-39752487f4e2"
33
authors = ["Mike J Innes <[email protected]>"]
4-
version = "0.2.0"
4+
version = "0.2.1"
55

66
[deps]
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/destructure.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ function _grad!(x, dx, off, flat::AbstractVector)
131131
flat
132132
end
133133
function _grad!(x, dx, off::Integer, flat::AbstractVector)
134-
@views flat[off .+ (1:length(x))] .+= dx # must visit all tied nodes
134+
@views flat[off .+ (1:length(x))] .+= vec(dx) # must visit all tied nodes
135135
flat
136136
end
137137
_grad!(x, dx::Zero, off, flat::AbstractVector) = dx

test/destructure.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,17 @@
22
m1 = collect(1:3.0)
33
m2 = (collect(1:3.0), collect(4:6.0))
44
m3 = (x = m1, y = sin, z = collect(4:6.0))
5+
56
m4 = (x = m1, y = m1, z = collect(4:6.0)) # tied
67
m5 = (a = (m3, true), b = (m1, false), c = (m4, true))
78
m6 = (a = m1, b = [4.0 + im], c = m1)
9+
810
m7 = TwoThirds((sin, collect(1:3.0)), (cos, collect(4:6.0)), (tan, collect(7:9.0)))
911
m8 = [Foo(m1, m1), (a = true, b = Foo([4.0], false), c = ()), [[5.0]]]
1012

13+
mat = Float32[4 6; 5 7]
14+
m9 = (a = m1, b = mat, c = [mat, m1])
15+
1116
@testset "flatten & rebuild" begin
1217
@test destructure(m1)[1] isa Vector{Float64}
1318
@test destructure(m1)[1] == 1:3
@@ -16,6 +21,7 @@ m8 = [Foo(m1, m1), (a = true, b = Foo([4.0], false), c = ()), [[5.0]]]
1621
@test destructure(m4)[1] == 1:6
1722
@test destructure(m5)[1] == vcat(1:6, 4:6)
1823
@test destructure(m6)[1] == vcat(1:3, 4 + im)
24+
@test destructure(m9)[1] == 1:7
1925

2026
@test destructure(m1)[2](7:9) == [7,8,9]
2127
@test destructure(m2)[2](4:9) == ([4,5,6], [7,8,9])
@@ -45,6 +51,10 @@ m8 = [Foo(m1, m1), (a = true, b = Foo([4.0], false), c = ()), [[5.0]]]
4551
@test m8′[2].b.y === false
4652
@test m8′[3][1] == [5.0]
4753

54+
m9′ = destructure(m9)[2](10:10:70)
55+
@test m9′.b === m9′.c[1]
56+
@test m9′.b isa Matrix{Float32}
57+
4858
# errors
4959
@test_throws Exception destructure(m7)[2]([10,20])
5060
@test_throws Exception destructure(m7)[2]([10,20,30,40])
@@ -71,6 +81,9 @@ end
7181
@test g8[2].b.x == [8]
7282
@test g8[3] == [[10.0]]
7383

84+
g9 = gradient(m -> sum(sqrt, destructure(m)[1]), m9)[1]
85+
@test g9.c === nothing
86+
7487
@testset "second derivative" begin
7588
@test gradient([1,2,3.0]) do v
7689
sum(abs2, gradient(m -> sum(abs2, destructure(m)[1]), (v, [4,5,6.0]))[1][1])
@@ -119,6 +132,9 @@ end
119132
@test gradient(x -> sum(abs2, re8(x)[1].y), v8)[1] == [2,4,6,0,0]
120133
@test gradient(x -> only(sum(re8(x)[3]))^2, v8)[1] == [0,0,0,0,10]
121134

135+
re9 = destructure(m9)[2]
136+
@test gradient(x -> sum(abs2, re9(x).c[1]), 1:7)[1] == [0,0,0, 8,10,12,14]
137+
122138
@testset "second derivative" begin
123139
@test_broken gradient(collect(1:6.0)) do y
124140
sum(abs2, gradient(x -> sum(abs2, re2(x)[1]), y)[1])

0 commit comments

Comments
 (0)