Skip to content

Commit cd06023

Browse files
committed
Add initial implementation
1 parent 5f17f1c commit cd06023

File tree

6 files changed

+69
-10
lines changed

6 files changed

+69
-10
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ Adapt = "3.0"
2929
ArrayInterface = "3.1, 4, 5"
3030
CUDA = "3"
3131
ChainRulesCore = "1.12"
32-
Functors = "0.2.1"
32+
Functors = "0.2.8"
3333
MLUtils = "0.2"
3434
MacroTools = "0.5"
3535
NNlib = "0.8.2"

src/Flux.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ include("layers/normalise.jl")
4646
include("layers/upsample.jl")
4747
include("layers/show.jl")
4848

49+
include("loading.jl")
50+
4951
include("outputsize.jl")
5052

5153
include("data/Data.jl")

src/deprecations.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ zeros(T::Type, dims...) = Base.zeros(T, dims...)
1515
ones32(::Type, dims...) = throw(ArgumentError("Flux.ones32 is always Float32, use Base.ones to specify the element type"))
1616
zeros32(::Type, dims...) = throw(ArgumentError("Flux.zeros32 is always Float32, use Base.zeros to specify the element type"))
1717

18+
# Keep the old `loadparams!` name as a deprecated forwarder to `loadmodel!`.
@deprecate loadparams!(m, xs) loadmodel!(m, xs)
19+
1820
# v0.13 deprecations
1921

2022
function Broadcast.broadcasted(f::Recur, args...)

src/functor.jl

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -85,14 +85,6 @@ function params(m...)
8585
return ps
8686
end
8787

88-
function loadparams!(m, xs)
89-
for (p, x) in zip(params(m), xs)
90-
size(p) == size(x) ||
91-
error("Expected param size $(size(p)), got $(size(x))")
92-
copyto!(p, x)
93-
end
94-
end
95-
9688
struct FluxCUDAAdaptor end
9789
adapt_storage(to::FluxCUDAAdaptor, x) = CUDA.cu(x)
9890
adapt_storage(to::FluxCUDAAdaptor, x::Zygote.FillArrays.AbstractFill) = CUDA.cu(collect(x))

src/loading.jl

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
# Predicate used as the `exclude` test when walking a model for loading.
# By default it agrees with `isleaf`; the layer types listed below are also
# reported as leaves so that they are passed to `loadto!` as whole layers.
_loadleaf(x) = isleaf(x)
for layer in (:Dense, :Diagonal, :Bilinear, :Embedding,
              :Conv, :ConvTranspose, :DepthwiseConv, :CrossCor)
    @eval _loadleaf(::$layer) = true
end
6+
7+
"""
    loadto!(x, x̄)

Load the contents of `x̄` into `x` and return `x`.

The generic fallback leaves `x` untouched. When both arguments are arrays, the
elements of `x̄` are copied into `x` in place.

# Throws
- `DimensionMismatch` if both arguments are arrays and their sizes differ.
"""
loadto!(x, x̄) = x
function loadto!(x::AbstractArray, x̄::AbstractArray)
    # `copyto!` copies by linear index and does not compare shapes, so e.g. a
    # 3×2 source would otherwise be loaded into a 2×3 destination without error.
    size(x) == size(x̄) || throw(DimensionMismatch(
        "Tried to load an array of size $(size(x̄)) into an array of size $(size(x))."))
    return copyto!(x, x̄)
end
9+
# Generate `loadto!` methods for each layer type below; the generated code
# reads the `weight` and `bias` fields of both layers.
for layer in (:Dense, :Bilinear, :Conv, :ConvTranspose, :DepthwiseConv, :CrossCor)
    @eval begin
        # Same layer type on both sides: require matching `weight` and `bias`
        # sizes, then load field by field with `fmap`.
        function loadto!(m::$layer, m̄::$layer)
            sizes_match = size(m.weight) == size(m̄.weight) && size(m.bias) == size(m̄.bias)
            sizes_match ||
                throw(DimensionMismatch("Tried to load $m̄ into $m but the parameter sizes do not match."))
            return fmap(loadto!, m, m̄)
        end
        # Any other source is rejected for this layer type.
        loadto!(m::$layer, m̄) = throw(ArgumentError("Tried to load $m̄ into $m."))
    end
end
21+
# Load one `Diagonal` layer into another. Both the `α` and `β` fields must
# agree in size before anything is copied.
function loadto!(m::Diagonal, m̄::Diagonal)
    sizes_match = size(m.α) == size(m̄.α) && size(m.β) == size(m̄.β)
    sizes_match ||
        throw(DimensionMismatch("Tried to load $m̄ into $m but the parameter sizes do not match."))
    return fmap(loadto!, m, m̄)
end
# Any other source is rejected for a `Diagonal` destination.
function loadto!(m::Diagonal, m̄)
    throw(ArgumentError("Tried to load $m̄ into $m."))
end
29+
# Load one `Embedding` layer into another. The `weight` fields must agree in
# size before anything is copied.
function loadto!(m::Embedding, m̄::Embedding)
    size(m.weight) == size(m̄.weight) ||
        throw(DimensionMismatch("Tried to load $m̄ into $m but the parameter sizes do not match."))
    return fmap(loadto!, m, m̄)
end
# Any other source is rejected for an `Embedding` destination.
function loadto!(m::Embedding, m̄)
    throw(ArgumentError("Tried to load $m̄ into $m."))
end
37+
38+
"""
    loadmodel!(m, xs::Params)

Copy each array in `xs` into the corresponding array of `params(m)`, pairing
the two collections in iteration order. The arrays of `m` are overwritten in
place.

An error is raised if the number of parameters differs, or if any pair of
arrays differs in size.
"""
function loadmodel!(m, xs::Params)
    ps = params(m)
    # `zip` stops at the shorter collection, so without this check a count
    # mismatch would leave some parameters unloaded (or some of `xs` unused)
    # without any error.
    length(ps) == length(xs) ||
        error("Expected $(length(ps)) params, got $(length(xs))")
    for (p, x) in zip(ps, xs)
        size(p) == size(x) ||
            error("Expected param size $(size(p)), got $(size(x))")
        copyto!(p, x)
    end
end
45+
# Vector input: pass it through `params` and load the result
# (presumably dispatching to the `Params` method — confirm against `params`).
function loadmodel!(m, xs::AbstractVector)
    return loadmodel!(m, params(xs))
end

# Structural load: walk `m` and `m̄` together with `fmap`, stopping wherever
# `_loadleaf` returns true and applying `loadto!` to each such pair.
function loadmodel!(m, m̄)
    return fmap(loadto!, m, m̄; exclude = _loadleaf)
end

test/utils.jl

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -373,11 +373,28 @@ end
373373
weights(m) = mapreduce(l -> [l.weight], vcat, m)
374374
@testset "Bias type $bt" for bt in (Flux.zeros32, nobias)
375375
m = dm(bt)
376-
loadparams!(m, params(m))
376+
loadmodel!(m, params(m))
377377
testdense(m, bt)
378378
end
379379
end
380380

381+
@testset "loadmodel!(m, m̄)" begin
  import Flux: loadmodel!

  m1 = Chain(Dense(10, 5), Dense(5, 2, relu))
  m2 = Chain(Dense(10, 5), Dense(5, 2))
  m3 = Chain(Conv((3, 3), 3 => 16), Dense(5, 2))
  m4 = Chain(Dense(10, 6), Dense(6, 2))

  loadmodel!(m1, m2)
  # Every layer must be loaded, not only the first one.
  for i in 1:2
    @test m1[i].weight == m2[i].weight
    @test m1[i].bias == m2[i].bias
  end
  # Values are copied into the destination's own arrays, not aliased.
  @test m1[1].weight !== m2[1].weight
  # Non-array fields of the destination are left alone.
  @test m1[2].σ == relu

  # Mismatched layer types and mismatched parameter sizes are both rejected.
  @test_throws ArgumentError loadmodel!(m1, m3)
  @test_throws DimensionMismatch loadmodel!(m1, m4)
end
397+
381398
@testset "destructure" begin
382399
import Flux: destructure
383400
@testset "Bias type $bt" for bt in (zeros, nobias)

0 commit comments

Comments
 (0)