
Commit 37a03a8

Resolve conflicts
Merge commit 37a03a8 (2 parents: 1b8b1bf + 674527e)

File tree: 19 files changed (+396, -169 lines)


NEWS.md

Lines changed: 2 additions & 0 deletions

@@ -11,6 +11,8 @@ been removed in favour of MLDatasets.jl.
 * Many utily functions and the `DataLoader` are [now provided by MLUtils.jl](https://github.com/FluxML/Flux.jl/pull/1874).
 * The DataLoader is now compatible with generic dataset types implementing `MLUtils.numobs` and `MLUtils.getobs`.
 * Added [truncated normal initialisation](https://github.com/FluxML/Flux.jl/pull/1877) of weights.
+* The `Flux.Diagonal` layer is now called `Scale`, and accepts an activation function.
+* `loadparams!` is replaced by [`loadmodel!`](https://github.com/FluxML/Flux.jl/pull/1875) which copies trainable + non-trainable parameters and performs more thorough structural checking

 ## v0.12.10
 * `Dropout`/`AlphaDropout` now supports [user-specified RNGs](https://github.com/FluxML/Flux.jl/pull/1838)
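Taken together, the two new entries above read roughly as follows in user code. This is a minimal sketch assuming the v0.13 API described in these notes; the layer sizes are illustrative only.

```julia
using Flux

# Flux.Diagonal is now Flux.Scale, and it accepts an activation function:
s = Flux.Scale(5, relu)     # element-wise y = relu.(scale .* x .+ bias)

# loadmodel! copies trainable and non-trainable state (e.g. BatchNorm statistics)
# from one model into another, checking that the two structures match:
m1 = Chain(Dense(10 => 5, relu), BatchNorm(5), Dense(5 => 2))
m2 = Chain(Dense(10 => 5, relu), BatchNorm(5), Dense(5 => 2))
Flux.loadmodel!(m2, m1)     # m2 now carries m1's parameters and running statistics
```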

Project.toml

Lines changed: 2 additions & 2 deletions

@@ -1,6 +1,6 @@
 name = "Flux"
 uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c"
-version = "0.13.0-DEV"
+version = "0.13.0"

 [deps]
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -29,7 +29,7 @@ Adapt = "3.0"
 ArrayInterface = "3.1, 4, 5"
 CUDA = "3"
 ChainRulesCore = "1.12"
-Functors = "0.2.1"
+Functors = "0.2.8"
 MLUtils = "0.2"
 MacroTools = "0.5"
 NNlib = "0.8.2"
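For downstream users this is an ordinary package upgrade. A hypothetical sketch of picking up the release from the REPL (nothing here is prescribed by the commit itself):

```julia
using Pkg

# Move from the 0.13.0-DEV prerelease to the released 0.13 series.
Pkg.add(name = "Flux", version = "0.13")

# The tightened Functors = "0.2.8" bound is resolved automatically.
Pkg.status("Flux")
```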

docs/src/models/layers.md

Lines changed: 1 addition & 2 deletions

@@ -25,7 +25,6 @@ CrossCor
 SamePad
 Flux.flatten
 Flux.convfilter
-Flux.depthwiseconvfilter
 ```

 ## Upsampling Layers
@@ -57,7 +56,7 @@ Maxout
 SkipConnection
 Parallel
 Flux.Bilinear
-Flux.Diagonal
+Flux.Scale
 Flux.Embedding
 ```

docs/src/saving.md

Lines changed: 22 additions & 24 deletions

@@ -2,7 +2,7 @@

 You may wish to save models so that they can be loaded and run in a later
 session. The easiest way to do this is via
-[BSON.jl](https://github.com/MikeInnes/BSON.jl).
+[BSON.jl](https://github.com/JuliaIO/BSON.jl).

 Save a model:

@@ -34,7 +34,6 @@ Chain(
   Dense(5 => 2),                        # 12 parameters
   NNlib.softmax,
 )                   # Total: 4 arrays, 67 parameters, 524 bytes.
-
 ```

 Models are just normal Julia structs, so it's fine to use any Julia storage
@@ -44,15 +43,17 @@ versions of Flux).

 !!! note

-    If a saved model's weights are stored on the GPU, the model will not load
+    If a saved model's parameters are stored on the GPU, the model will not load
     later on if there is no GPU support available. It's best to [move your model
     to the CPU](gpu.md) with `cpu(model)` before saving it.

-## Saving Model Weights
+!!! warning

-In some cases it may be useful to save only the model parameters themselves, and
-rebuild the model architecture in your code. You can use `params(model)` to get
-model parameters.
+    Previous versions of Flux suggested saving only the model weights using
+    `@save "mymodel.bson" params(model)`.
+    This is no longer recommended and even strongly discouraged.
+    Saving models this way will only store the trainable parameters which
+    will result in incorrect behavior for layers like `BatchNorm`.

 ```jldoctest saving
 julia> model = Chain(Dense(10 => 5,relu),Dense(5 => 2),softmax)
@@ -64,29 +65,26 @@ Chain(

 julia> weights = Flux.params(model);

-julia> using BSON: @save
-
-julia> @save "mymodel.bson" weights
-```
-
-You can easily load parameters back into a model with `Flux.loadparams!`.
+Loading the model as shown above will return a new model with the stored parameters.
+But sometimes you already have a model, and you want to load stored parameters into it.
+This can be done as

 ```julia
-julia> model = Chain(Dense(10 => 5,relu),Dense(5 => 2),softmax)
-Chain(
-  Dense(10 => 5, relu),                 # 55 parameters
-  Dense(5 => 2),                        # 12 parameters
-  NNlib.softmax,
-)                   # Total: 4 arrays, 67 parameters, 524 bytes.
+using Flux: loadmodel!
+using BSON: @load

-julia> using BSON: @load
+# some predefined model
+model = Chain(Dense(10 => 5, relu), Dense(5 => 2), softmax)

-julia> @load "mymodel.bson" weights
-
-julia> Flux.loadparams!(model, weights)
+# load one model into another
+model = loadmodel!(model, @load("mymodel.bson"))
 ```

-The new `model` we created will now be identical to the one we saved parameters for.
+This ensures that the model loaded from `"mymodel.bson"` matches the structure of `model`. [`Flux.loadmodel!`](@ref) is also convenient for copying parameters between models in memory.
+
+```@docs
+Flux.loadmodel!
+```

 ## Checkpointing
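As a quick end-to-end check of the workflow these docs now recommend, a round trip through disk might look like the sketch below. The file name is illustrative and this is not part of the committed docs; it assumes BSON.jl is installed alongside Flux v0.13.

```julia
using Flux, BSON

model = Chain(Dense(10 => 5, relu), Dense(5 => 2), softmax)
BSON.@save "mymodel.bson" model        # save the whole struct, not just params(model)

model2 = Chain(Dense(10 => 5, relu), Dense(5 => 2), softmax)  # freshly initialised
BSON.@load "mymodel.bson" model        # brings the saved `model` back into scope
Flux.loadmodel!(model2, model)         # copy the stored state into the existing model2

@assert model2(ones(Float32, 10)) == model(ones(Float32, 10))
```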

src/Flux.jl

Lines changed: 2 additions & 0 deletions

@@ -46,6 +46,8 @@ include("layers/normalise.jl")
 include("layers/upsample.jl")
 include("layers/show.jl")

+include("loading.jl")
+
 include("outputsize.jl")

 include("data/Data.jl")

src/deprecations.jl

Lines changed: 21 additions & 2 deletions

@@ -1,13 +1,13 @@
 # v0.12 deprecations

 function ones(dims...)
-  Base.depwarn("Flux.ones(size...) is deprecated, please use Flux.ones32(size...) or Base.ones(Float32, size...)", :ones)
+  Base.depwarn("Flux.ones(size...) is deprecated, please use Flux.ones32(size...) or Base.ones(Float32, size...)", :ones, force=true)
   Base.ones(Float32, dims...)
 end
 ones(T::Type, dims...) = Base.ones(T, dims...)

 function zeros(dims...)
-  Base.depwarn("Flux.zeros(size...) is deprecated, please use Flux.zeros32(size...) or Base.zeros(Float32, size...)", :zeros)
+  Base.depwarn("Flux.zeros(size...) is deprecated, please use Flux.zeros32(size...) or Base.zeros(Float32, size...)", :zeros, force=true)
   Base.zeros(Float32, dims...)
 end
 zeros(T::Type, dims...) = Base.zeros(T, dims...)
@@ -39,6 +39,25 @@ function Optimise.update!(x::AbstractArray, x̄)
   x .-= x̄
 end

+function Diagonal(size::Integer...; kw...)
+  Base.depwarn("Flux.Diagonal is now Flux.Scale, and also allows an activation function.", :Diagonal)
+  Scale(size...; kw...)
+end
+function Diagonal(size::Tuple; kw...)
+  Base.depwarn("Flux.Diagonal is now Flux.Scale, and also allows an activation function.", :Diagonal)
+  Scale(size...; kw...)
+end
+
+# Deprecate this eventually once saving models w/o structure is no more
+function loadparams!(m, xs)
+  Base.depwarn("loadparams! will be deprecated eventually. Use loadmodel! instead.", :loadparams!)
+  for (p, x) in zip(params(m), xs)
+    size(p) == size(x) ||
+      error("Expected param size $(size(p)), got $(size(x))")
+    copyto!(p, x)
+  end
+end
+
 # Channel notation: Changed to match Conv, but very softly deprecated!
 # Perhaps change to @deprecate for v0.14, but there is no plan to remove these.
 Dense(in::Integer, out::Integer, σ = identity; kw...) =
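With these forwarding methods in place, existing user code keeps working while pointing at the new names. A small sketch of the expected behaviour under the definitions above (the warnings are only printed when Julia is started with --depwarn=yes, since these depwarns do not set force=true):

```julia
using Flux

d = Flux.Diagonal(3)          # forwards to Scale(3) and emits a depwarn
d isa Flux.Scale              # true

m = Chain(Dense(2 => 2))
ps = collect(Flux.params(m))
Flux.loadparams!(m, ps)       # still copies size-matched arrays, but suggests loadmodel!
```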

src/functor.jl

Lines changed: 0 additions & 8 deletions

@@ -85,14 +85,6 @@ function params(m...)
   return ps
 end

-function loadparams!(m, xs)
-  for (p, x) in zip(params(m), xs)
-    size(p) == size(x) ||
-      error("Expected param size $(size(p)), got $(size(x))")
-    copyto!(p, x)
-  end
-end
-
 struct FluxCUDAAdaptor end
 adapt_storage(to::FluxCUDAAdaptor, x) = CUDA.cu(x)
 adapt_storage(to::FluxCUDAAdaptor, x::Zygote.FillArrays.AbstractFill) = CUDA.cu(collect(x))

src/layers/basic.jl

Lines changed: 49 additions & 19 deletions

@@ -156,9 +156,8 @@ end
 @functor Dense

 function (a::Dense)(x::AbstractVecOrMat)
-  W, b = a.weight, a.bias
   σ = NNlib.fast_act(a.σ, x) # replaces tanh => tanh_fast, etc
-  return σ.(W*x .+ b)
+  return σ.(a.weight * x .+ a.bias)
 end

 (a::Dense)(x::AbstractArray) =
@@ -172,38 +171,69 @@ function Base.show(io::IO, l::Dense)
 end

 """
-    Diagonal(size::Integer...; bias=true, init=ones32)
-    Diagonal(scale::AbstractArray, [bias])
+    Scale(size::Integer..., σ=identity; bias=true, init=ones32)
+    Scale(scale::AbstractArray, [bias, σ])

-Create an element-wise linear layer, which performs
+Create an element-wise layer, whose forward pass is given by:

-    y = scale .* x .+ bias
+    y = σ.(scale .* x .+ bias)

-with no activation function.
-
+This uses `.*` instead of matrix multiplication `*` of [`Dense`](@ref).
+
 The learnable scale & bias are initialised `init(size...)` and `zeros32(size...)`,
 with `init=ones32` by default. You may specify the function `init`,
 turn off trainable bias with `bias=false`, or provide the array(s) explicitly.

-Used by [`LayerNorm`](@ref).
+Used by [`LayerNorm`](@ref) with `affine=true`.
+
+# Examples
+```jldoctest
+julia> a = Flux.Scale(2)
+Scale(2)  # 4 parameters
+
+julia> Flux.params(a)
+Params([Float32[1.0, 1.0], Float32[0.0, 0.0]])
+
+julia> a([1 2 3])
+2×3 Matrix{Float32}:
+ 1.0  2.0  3.0
+ 1.0  2.0  3.0
+
+julia> b = Flux.Scale([1 2 3 4], false, abs2)
+Scale(1, 4, abs2; bias=false)  # 4 parameters
+
+julia> b([1, 10])
+2×4 Matrix{Int64}:
+   1    4    9    16
+ 100  400  900  1600
+
+julia> Flux.params(b)
+Params([[1 2 3 4]])
+```
 """
-struct Diagonal{A<:AbstractArray, B}
+struct Scale{F, A<:AbstractArray, B}
   scale::A
   bias::B
-  function Diagonal(W::M, bias = true) where M<:AbstractArray
-    b = create_bias(W, bias, size(W)...)
-    new{M, typeof(b)}(W, b)
+  σ::F
+  function Scale(scale::A, bias::B = true, σ::F = identity) where {A<:AbstractArray, B<:Union{Bool, AbstractArray}, F}
+    b = create_bias(scale, bias, size(scale)...)
+    new{F, A, typeof(b)}(scale, b, σ)
   end
 end

-Diagonal(sz::Integer...; bias = true, init = ones32) = Diagonal(init(sz...), bias)
+Scale(s1::Integer, s23::Integer...; bias = true, init = ones32, _act = identity) = Scale(init(s1, s23...), bias, _act)
+Scale(size_act...; bias = true, init = ones32) = Scale(size_act[1:end-1]...; bias, init, _act = size_act[end])

-@functor Diagonal
+@functor Scale

-(a::Diagonal)(x) = a.scale .* x .+ a.bias
+function (a::Scale)(x::AbstractArray)
+  σ = NNlib.fast_act(a.σ, x) # replaces tanh => tanh_fast, etc
+  σ.(a.scale .* x .+ a.bias)
+end

-function Base.show(io::IO, l::Diagonal)
-  print(io, "Diagonal(", join(size(l.scale), ", "))
+function Base.show(io::IO, l::Scale)
+  print(io, "Scale(", join(size(l.scale), ", "))
+  l.σ == identity || print(io, ", ", l.σ)
   l.bias == false && print(io, "; bias=false")
   print(io, ")")
 end
@@ -212,7 +242,7 @@ end
    Maxout(layers...)
    Maxout(f, n_alts)

-This contains a number of internal layes, each of which receives the same input.
+This contains a number of internal layers, each of which receives the same input.
 Its output is the elementwise maximum of the the internal layers' outputs.

 Instead of defining layers individually, you can provide a zero-argument function