Skip to content

Commit a155a44

Browse files
committed
Fix tests
1 parent 0790f24 commit a155a44

File tree

2 files changed

+23
-20
lines changed

2 files changed

+23
-20
lines changed

src/loading.jl

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,26 +18,29 @@ By default, use `copyto!` when `x` and `x̄` are arrays.
1818
Otherwise, just return `x`.
1919
"""
2020
loadleaf!(x, x̄, err) = x
21-
loadleaf!(x::Zeros, x̄, err) = x
21+
function loadleaf!(x::AbstractArray, x̄, err)
22+
x .=
23+
return x
24+
end
2225
function loadleaf!(x::AbstractArray, x̄::AbstractArray, err)
23-
(size(x) == size(x̄)) || throw(err)
24-
copyto!(x, x̄)
26+
(size(x) == size(x̄)) || throw(err)
27+
copyto!(x, x̄)
2528
end
2629

2730
function _loadto!(m, m̄)
28-
ls, _ = functor(m)
29-
l̄s, _ = functor(m̄)
30-
(keys(ls) == keys(l̄s)) ||
31-
throw(ArgumentError("Tried to load $m̄ into $m but the structures do not match."))
31+
ls, _ = functor(m)
32+
l̄s, _ = functor(m̄)
33+
(keys(ls) == keys(l̄s)) ||
34+
throw(ArgumentError("Tried to load $m̄ into $m but the structures do not match."))
3235

33-
err = DimensionMismatch("Tried to load $m̄ into $m but the parameter sizes do not match.")
34-
foreach((l, l̄) -> loadleaf!(l, l̄, err), ls, l̄s)
36+
err = DimensionMismatch("Tried to load $m̄ into $m but the parameter sizes do not match.")
37+
foreach((l, l̄) -> loadleaf!(l, l̄, err), ls, l̄s)
3538

36-
return m
39+
return m
3740
end
3841
function loadto!(m::T, m̄::S) where {T, S}
39-
(nameof(T) == nameof(S)) || throw(ArgumentError("Tried to load $m̄ into $m."))
40-
_loadto!(m, m̄)
42+
(nameof(T) == nameof(S)) || throw(ArgumentError("Tried to load $m̄ into $m."))
43+
_loadto!(m, m̄)
4144
end
4245

4346
"""

test/utils.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ using Flux
22
using Flux: throttle, nfan, glorot_uniform, glorot_normal,
33
kaiming_normal, kaiming_uniform, orthogonal, truncated_normal,
44
sparse_init, identity_init, stack, unstack, batch, unbatch,
5-
unsqueeze, params, loadparams!
5+
unsqueeze, params, loadmodel!
66
using StatsBase: var, std
77
using Statistics, LinearAlgebra
88
using Random
@@ -366,14 +366,14 @@ end
366366
@test_skip typeof(l1.bias) === typeof(l2.bias)
367367
end
368368

369-
@testset "loadparams!" begin
369+
@testset "loadmodel!" begin
370370
pars(w, b) = [w, b]
371371
pars(l) = pars(l.weight, l.bias)
372372
pararray(m) = mapreduce(pars, vcat, m)
373373
weights(m) = mapreduce(l -> [l.weight], vcat, m)
374374
@testset "Bias type $bt" for bt in (Flux.zeros32, nobias)
375375
m = dm(bt)
376-
loadmodel!(m, params(m))
376+
Flux.loadmodel!(m, params(m))
377377
testdense(m, bt)
378378
end
379379
end
@@ -421,22 +421,22 @@ end
421421
end
422422
end
423423

424-
@testset "loadparams! & absent bias" begin
424+
@testset "loadmodel! & absent bias" begin
425425
m0 = Chain(Dense(2 => 3; bias=false, init = Flux.ones32), Dense(3 => 1))
426426
m1 = Chain(Dense(2 => 3; bias = Flux.randn32(3)), Dense(3 => 1))
427427
m2 = Chain(Dense(Float32[1 2; 3 4; 5 6], Float32[7, 8, 9]), Dense(3 => 1))
428428

429-
Flux.loadparams!(m1, Flux.params(m2))
429+
Flux.loadmodel!(m1, m2)
430430
@test m1[1].bias == 7:9
431431
@test sum(m1[1].weight) == 21
432432

433433
# load from a model without bias -- should ideally recognise the `false` but `Params` doesn't store it
434-
@test_broken Flux.loadparams!(m1, Flux.params(m0))
435-
@test_broken iszero(m1[1].bias)
434+
m1 = Flux.loadmodel!(m1, m0)
435+
@test iszero(m1[1].bias)
436436
@test sum(m1[1].weight) == 6 # written before error
437437

438438
# load into a model without bias -- should it ignore the parameter which has no home, or error?
439-
@test_broken Flux.loadparams!(m0, Flux.params(m2))
439+
m0 = Flux.loadmodel!(m0, m2)
440440
@test iszero(m0[1].bias) # obviously unchanged
441441
@test sum(m0[1].weight) == 21
442442
end

0 commit comments

Comments
 (0)