@@ -379,20 +379,27 @@ end
379
379
end
380
380
381
381
@testset " loadmodel!(m, m̄)" begin
382
- import Flux: loadmodel!
382
+ import Flux: loadmodel!, Zeros
383
383
384
384
m1 = Chain (Dense (10 , 5 ), Dense (5 , 2 , relu))
385
385
m2 = Chain (Dense (10 , 5 ), Dense (5 , 2 ))
386
386
m3 = Chain (Conv ((3 , 3 ), 3 => 16 ), Dense (5 , 2 ))
387
387
m4 = Chain (Dense (10 , 6 ), Dense (6 , 2 ))
388
+ m5 = Chain (Dense (10 , 5 ), Parallel (+ , Dense (Flux. ones32 (5 , 2 ), Zeros ()), Dense (5 , 2 )))
389
+ m6 = Chain (Dense (10 , 5 ), Parallel (+ , Dense (5 , 2 ), Dense (5 , 2 )))
388
390
389
391
loadmodel! (m1, m2)
390
392
@test m1[1 ]. weight == m2[1 ]. weight
391
393
@test m1[1 ]. bias == m2[1 ]. bias
392
394
@test m1[2 ]. σ == relu
395
+ loadmodel! (m5, m6)
396
+ @test m5[1 ]. weight == m6[1 ]. weight
397
+ @test m5[2 ][1 ]. weight == m6[2 ][1 ]. weight
398
+ @test m5[2 ][1 ]. bias == Zeros ()
393
399
394
400
@test_throws ArgumentError loadmodel! (m1, m3)
395
401
@test_throws DimensionMismatch loadmodel! (m1, m4)
402
+ @test_throws ArgumentError loadmodel! (m1, m5)
396
403
end
397
404
398
405
@testset " destructure" begin
0 commit comments