Skip to content

Commit 99a18ec

Browse files
committed
Add more tests
1 parent cd06023 commit 99a18ec

File tree

2 files changed

+9
-1
lines changed

2 files changed

+9
-1
lines changed

src/loading.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ for T in [:Dense, :Diagonal, :Bilinear, :Embedding,
55
end
66

77
loadto!(x, x̄) = x
8+
loadto!(x::Zeros, x̄) = x
89
loadto!(x::AbstractArray, x̄::AbstractArray) = copyto!(x, x̄)
910
for T in [:Dense, :Bilinear, :Conv, :ConvTranspose, :DepthwiseConv, :CrossCor]
1011
@eval begin

test/utils.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -379,20 +379,27 @@ end
379379
end
380380

381381
@testset "loadmodel!(m, m̄)" begin
382-
import Flux: loadmodel!
382+
import Flux: loadmodel!, Zeros
383383

384384
m1 = Chain(Dense(10, 5), Dense(5, 2, relu))
385385
m2 = Chain(Dense(10, 5), Dense(5, 2))
386386
m3 = Chain(Conv((3, 3), 3 => 16), Dense(5, 2))
387387
m4 = Chain(Dense(10, 6), Dense(6, 2))
388+
m5 = Chain(Dense(10, 5), Parallel(+, Dense(Flux.ones32(5, 2), Zeros()), Dense(5, 2)))
389+
m6 = Chain(Dense(10, 5), Parallel(+, Dense(5, 2), Dense(5, 2)))
388390

389391
loadmodel!(m1, m2)
390392
@test m1[1].weight == m2[1].weight
391393
@test m1[1].bias == m2[1].bias
392394
@test m1[2].σ == relu
395+
loadmodel!(m5, m6)
396+
@test m5[1].weight == m6[1].weight
397+
@test m5[2][1].weight == m6[2][1].weight
398+
@test m5[2][1].bias == Zeros()
393399

394400
@test_throws ArgumentError loadmodel!(m1, m3)
395401
@test_throws DimensionMismatch loadmodel!(m1, m4)
402+
@test_throws ArgumentError loadmodel!(m1, m5)
396403
end
397404

398405
@testset "destructure" begin

0 commit comments

Comments
 (0)