Skip to content

Commit 694bb9b

Browse files
authored
Allow destructure to make an empty vector (#68)
* allow destructure to make an empty vector * bump version, rm undef exports * change to Bool[]
1 parent dd571ca commit 694bb9b

File tree

4 files changed

+24
-2
lines changed

4 files changed

+24
-2
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Optimisers"
22
uuid = "3bd65402-5787-11e9-1adc-39752487f4e2"
33
authors = ["Mike J Innes <[email protected]>"]
4-
version = "0.2.2"
4+
version = "0.2.3"
55

66
[deps]
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/Optimisers.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ using LinearAlgebra
66
include("interface.jl")
77

88
include("destructure.jl")
9-
export destructure, total, total2
9+
export destructure
1010

1111
include("rules.jl")
1212
export Descent, ADAM, Momentum, Nesterov, RMSProp,

src/destructure.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ function _flatten(x)
6464
len[] = o + length(y)
6565
o
6666
end
67+
isempty(arrays) && return Bool[], off, 0
6768
reduce(vcat, arrays), off, len[]
6869
end
6970

test/destructure.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,3 +220,24 @@ tmp1
220220
@test tmp2 isa Vector{<:ForwardDiff.Dual}
221221
222222
=#
223+
224+
@testset "empty, issue 67" begin
225+
m0 = (nothing, missing, isempty)
226+
@test destructure(m0)[1] isa Vector{<:Real}
227+
v0, re0 = destructure(m0)
228+
@test re0(Float32[]) === m0
229+
@test_throws DimensionMismatch re0([1])
230+
231+
# This is an elaborate way of checking that it doesn't cause promotions, even of small floats:
232+
m01 = [(x=nothing, y=0), (x=Float16[1, 2], y=Float16[3])]
233+
v01, _ = destructure(m01)
234+
v012 = vcat(destructure(m01[1])[1], destructure(m01[2])[1])
235+
@test v01 == v012
236+
@test v012 isa Vector{Float16}
237+
238+
y, bk = Zygote.pullback(x -> sum(destructure(x)[1]), ("a", :beta))
239+
@test bk(1.0) == (nothing,)
240+
# Zygote regards 3,4 as differentiable, but Optimisers does not regard them as parameters:
241+
y, bk = Zygote.pullback(x -> sum(destructure(x)[1]), (3, 4))
242+
@test bk(1.0) == (nothing,)
243+
end

0 commit comments

Comments
 (0)