Skip to content

Commit 3ee0747

Browse files
authored
Merge pull request #2194 from jonathanBieler/master
fixed BSON loadmodel! documentation and added a test case
2 parents cebc0d9 + 3b8c61a commit 3ee0747

File tree

3 files changed

+16
-3
lines changed

3 files changed

+16
-3
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
4848
IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
4949
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
5050
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
51+
BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
5152

5253
[targets]
53-
test = ["Test", "Documenter", "IterTools", "LinearAlgebra", "FillArrays", "ComponentArrays"]
54+
test = ["Test", "Documenter", "IterTools", "LinearAlgebra", "FillArrays", "ComponentArrays", "BSON"]

docs/src/saving.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,13 +76,13 @@ This can be done as
7676

7777
```julia
7878
using Flux: loadmodel!
79-
using BSON: @load
79+
using BSON
8080

8181
# some predefined model
8282
model = Chain(Dense(10 => 5, relu), Dense(5 => 2), softmax)
8383

8484
# load one model into another
85-
model = loadmodel!(model, @load("mymodel.bson"))
85+
model = loadmodel!(model, BSON.load("mymodel.bson")[:model])
8686
```
8787

8888
This ensures that the model loaded from `"mymodel.bson"` matches the structure of `model`. [`Flux.loadmodel!`](@ref) is also convenient for copying parameters between models in memory.

test/utils.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ using StatsBase: var, std
88
using Statistics, LinearAlgebra
99
using Random
1010
using Test
11+
using BSON
1112

1213
@testset "Throttle" begin
1314
@testset "default behaviour" begin
@@ -560,6 +561,17 @@ end
560561
end
561562
end
562563

564+
@testset "loadmodel!(dst, src) with BSON" begin
565+
m1 = Chain(Dense(Float32[1 2; 3 4; 5 6], Float32[7, 8, 9]), Dense(3 => 1))
566+
m2 = Chain(Dense(Float32[0 0; 0 0; 0 0], Float32[0, 0, 0]), Dense(3 => 1))
567+
@test m1[1].weight != m2[1].weight
568+
mktempdir() do dir
569+
BSON.@save joinpath(dir, "test.bson") m1
570+
m2 = Flux.loadmodel!(m2, BSON.load(joinpath(dir, "test.bson"))[:m1])
571+
@test m1[1].weight == m2[1].weight
572+
end
573+
end
574+
563575
@testset "destructure" begin
564576
import Flux: destructure
565577
@testset "Bias type $bt" for bt in (zeros, nobias)

0 commit comments

Comments
 (0)