Skip to content

Commit dee5842

Browse files
committed
Refactor to allow better support for loading errors with custom models
1 parent 492c34e commit dee5842

File tree

1 file changed

+21
-34
lines changed

1 file changed

+21
-34
lines changed

src/loading.jl

Lines changed: 21 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,27 @@
1-
_loadleaf(x) = isleaf(x)
2-
for T in [:Dense, :Diagonal, :Bilinear, :Embedding,
3-
:Conv, :ConvTranspose, :DepthwiseConv, :CrossCor]
4-
@eval _loadleaf(::$T) = true
5-
end
1+
isloadleaf(x) = all(Functors.isleaf, Functors.children(x))
62

7-
loadto!(x, x̄) = x
8-
loadto!(x::Zeros, x̄) = x
9-
loadto!(x::AbstractArray, x̄::AbstractArray) = copyto!(x, x̄)
10-
for T in [:Dense, :Bilinear, :Conv, :ConvTranspose, :DepthwiseConv, :CrossCor]
11-
@eval begin
12-
function loadto!(m::$T, m̄::$T)
13-
if (size(m.weight) != size(m̄.weight)) || (size(m.bias) != size(m̄.bias))
14-
throw(DimensionMismatch("Tried to load $m̄ into $m but the parameter sizes do not match."))
15-
else
16-
return fmap(loadto!, m, m̄)
17-
end
18-
end
19-
loadto!(m::$T, m̄) = throw(ArgumentError("Tried to load $m̄ into $m."))
20-
end
3+
loadnumeric!(x, x̄, err) = x
4+
loadnumeric!(x::Zeros, x̄, err) = x
5+
function loadnumeric!(x::AbstractArray, x̄::AbstractArray, err)
6+
(size(x) == size(x̄)) || throw(err)
7+
copyto!(x, x̄)
218
end
22-
function loadto!(m::Diagonal, m̄::Diagonal)
23-
if (size(m.α) != size(m̄.α)) || (size(m.β) != size(m̄.β))
24-
throw(DimensionMismatch("Tried to load $m̄ into $m but the parameter sizes do not match."))
25-
else
26-
return fmap(loadto!, m, m̄)
27-
end
9+
10+
function _loadto!(m, m̄)
11+
ls, _ = functor(m)
12+
l̄s, _ = functor(m̄)
13+
(keys(ls) == keys(l̄s)) ||
14+
throw(ArgumentError("Tried to load $m̄ into $m but the structures do not match."))
15+
16+
err = DimensionMismatch("Tried to load $m̄ into $m but the parameter sizes do not match.")
17+
foreach((l, l̄) -> loadnumeric!(l, l̄, err), ls, l̄s)
18+
19+
return m
2820
end
29-
loadto!(m::Diagonal, m̄) = throw(ArgumentError("Tried to load $m̄ into $m."))
30-
function loadto!(m::Embedding, m̄::Embedding)
31-
if size(m.weight) != size(m̄.weight)
32-
throw(DimensionMismatch("Tried to load $m̄ into $m but the parameter sizes do not match."))
33-
else
34-
return fmap(loadto!, m, m̄)
35-
end
21+
function loadto!(m::T, m̄::S) where {T, S}
22+
(nameof(T) == nameof(S)) || throw(ArgumentError("Tried to load $m̄ into $m."))
23+
_loadto!(m, m̄)
3624
end
37-
loadto!(m::Embedding, m̄) = throw(ArgumentError("Tried to load $m̄ into $m."))
3825

3926
function loadmodel!(m, xs::Params)
4027
for (p, x) in zip(params(m), xs)
@@ -44,4 +31,4 @@ function loadmodel!(m, xs::Params)
4431
end
4532
end
4633
loadmodel!(m, xs::AbstractVector) = loadmodel!(m, params(xs))
47-
loadmodel!(m, m̄) = fmap(loadto!, m, m̄; exclude = _loadleaf)
34+
loadmodel!(m, m̄) = fmap(loadto!, m, m̄; exclude = isloadleaf)

0 commit comments

Comments
 (0)