1
- _loadleaf (x) = isleaf (x)
2
- for T in [:Dense , :Diagonal , :Bilinear , :Embedding ,
3
- :Conv , :ConvTranspose , :DepthwiseConv , :CrossCor ]
4
- @eval _loadleaf (:: $T ) = true
5
- end
1
+ isloadleaf (x) = all (Functors. isleaf, Functors. children (x))
6
2
7
- loadto! (x, x̄) = x
8
- loadto! (x:: Zeros , x̄) = x
9
- loadto! (x:: AbstractArray , x̄:: AbstractArray ) = copyto! (x, x̄)
10
- for T in [:Dense , :Bilinear , :Conv , :ConvTranspose , :DepthwiseConv , :CrossCor ]
11
- @eval begin
12
- function loadto! (m:: $T , m̄:: $T )
13
- if (size (m. weight) != size (m̄. weight)) || (size (m. bias) != size (m̄. bias))
14
- throw (DimensionMismatch (" Tried to load $m̄ into $m but the parameter sizes do not match." ))
15
- else
16
- return fmap (loadto!, m, m̄)
17
- end
18
- end
19
- loadto! (m:: $T , m̄) = throw (ArgumentError (" Tried to load $m̄ into $m ." ))
20
- end
3
+ loadnumeric! (x, x̄, err) = x
4
+ loadnumeric! (x:: Zeros , x̄, err) = x
5
+ function loadnumeric! (x:: AbstractArray , x̄:: AbstractArray , err)
6
+ (size (x) == size (x̄)) || throw (err)
7
+ copyto! (x, x̄)
21
8
end
22
- function loadto! (m:: Diagonal , m̄:: Diagonal )
23
- if (size (m. α) != size (m̄. α)) || (size (m. β) != size (m̄. β))
24
- throw (DimensionMismatch (" Tried to load $m̄ into $m but the parameter sizes do not match." ))
25
- else
26
- return fmap (loadto!, m, m̄)
27
- end
9
+
10
+ function _loadto! (m, m̄)
11
+ ls, _ = functor (m)
12
+ l̄s, _ = functor (m̄)
13
+ (keys (ls) == keys (l̄s)) ||
14
+ throw (ArgumentError (" Tried to load $m̄ into $m but the structures do not match." ))
15
+
16
+ err = DimensionMismatch (" Tried to load $m̄ into $m but the parameter sizes do not match." )
17
+ foreach ((l, l̄) -> loadnumeric! (l, l̄, err), ls, l̄s)
18
+
19
+ return m
28
20
end
29
- loadto! (m:: Diagonal , m̄) = throw (ArgumentError (" Tried to load $m̄ into $m ." ))
30
- function loadto! (m:: Embedding , m̄:: Embedding )
31
- if size (m. weight) != size (m̄. weight)
32
- throw (DimensionMismatch (" Tried to load $m̄ into $m but the parameter sizes do not match." ))
33
- else
34
- return fmap (loadto!, m, m̄)
35
- end
21
+ function loadto! (m:: T , m̄:: S ) where {T, S}
22
+ (nameof (T) == nameof (S)) || throw (ArgumentError (" Tried to load $m̄ into $m ." ))
23
+ _loadto! (m, m̄)
36
24
end
37
- loadto! (m:: Embedding , m̄) = throw (ArgumentError (" Tried to load $m̄ into $m ." ))
38
25
39
26
function loadmodel! (m, xs:: Params )
40
27
for (p, x) in zip (params (m), xs)
@@ -44,4 +31,4 @@ function loadmodel!(m, xs::Params)
44
31
end
45
32
end
46
33
loadmodel! (m, xs:: AbstractVector ) = loadmodel! (m, params (xs))
47
- loadmodel! (m, m̄) = fmap (loadto!, m, m̄; exclude = _loadleaf )
34
+ loadmodel! (m, m̄) = fmap (loadto!, m, m̄; exclude = isloadleaf )
0 commit comments