Merged
Changes from 8 commits
3 changes: 2 additions & 1 deletion docs/src/models/layers.md
@@ -145,6 +145,7 @@ Several normalisation layers behave differently under training and inference (te
The functions `Flux.trainmode!` and `Flux.testmode!` let you manually specify which behaviour you want. When called on a model, they will place all layers within the model into the specified mode.

```@docs
Flux.testmode!
testmode!(::Any)
testmode!(::Any, ::Any)
trainmode!
```
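
As context for the docstring references above, here is a minimal sketch of the manual mode switching that this `@docs` block now documents (assuming a small `Chain` with a `Dropout` layer; the model and sizes are illustrative, not part of the diff):

```julia
using Flux

m = Chain(Dense(2 => 3), Dropout(0.4))

Flux.testmode!(m)         # force inference behaviour: Dropout disabled in every layer
y_eval = m(ones(Float32, 2, 7))

Flux.trainmode!(m)        # force training behaviour: Dropout active even outside gradients
y_train = m(ones(Float32, 2, 7))

Flux.testmode!(m, :auto)  # restore the default automatic mode detection
```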
17 changes: 17 additions & 0 deletions src/deprecations.jl
@@ -187,6 +187,23 @@ function update!(opt::Optimise.AbstractOptimiser, ::Params, grads::Union{Tuple,
""")
end

"""
trainmode!(m, active)

!!! warning
This two-argument method is deprecated.

Possible values of `active` are:
- `true` for training, or
- `false` for testing, same as [`testmode!`](@ref)`(m)`
- `:auto` or `nothing` for Flux to detect training automatically.
"""
function trainmode!(m, active::Bool)
Base.depwarn("trainmode!(m, active::Bool) is deprecated", :trainmode)
testmode!(m, !active)
end


# v0.14 deprecations

# Enable these when 0.14 is released, and delete const ClipGrad = Optimise.ClipValue etc:
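The deprecation above keeps the old two-argument `trainmode!` call working while pointing users at the one-argument forms. A minimal migration sketch (assuming a model `m` such as a `Chain`; not part of the diff):

```julia
using Flux

m = Chain(Dense(2 => 3), Dropout(0.4))

# Deprecated: emits a depwarn and forwards to testmode!(m, !active)
# trainmode!(m, true)
# trainmode!(m, false)

# Preferred replacements
Flux.trainmode!(m)        # was trainmode!(m, true)
Flux.testmode!(m)         # was trainmode!(m, false)
Flux.testmode!(m, :auto)  # back to automatic mode detection
```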
68 changes: 48 additions & 20 deletions src/functor.jl
@@ -5,36 +5,64 @@ import Functors: Functors, @functor, functor, fmap, isleaf
using SparseArrays: AbstractSparseArray

"""
testmode!(m, mode = true)
testmode!(m, [mode])

Set a layer or model's test mode (see below).
Using `:auto` mode will treat any gradient computation as training.
Set a layer, or all layers in a model, to test mode.
This disables the effect of [`Dropout`](@ref) and
some other regularisation layers.

_Note_: if you manually set a model into test mode, you need to manually place
it back into train mode during training phase.
If you manually set a model into test mode, you need to manually place
it back into train mode during the training phase, using [`trainmode!`](@ref).

There is an optional second argument, which takes a symbol `:auto` to
reset all layers back to the default automatic mode.

# Example

```jldoctest
julia> d = Dropout(0.3)
Dropout(0.3)

julia> testmode!(d) # dropout is now always disabled
Dropout(0.3, active=false)

julia> trainmode!(d) # dropout is now always enabled
Dropout(0.3, active=true)

julia> trainmode!(d, :auto) # back to default
Dropout(0.3)
```
"""
testmode!(m) = testmode!(m, true)

Possible values include:
- `false` for training
- `true` for testing
- `:auto` or `nothing` for Flux to detect the mode automatically
"""
testmode!(m, mode = true) = (foreach(x -> testmode!(x, mode), trainable(m)); m)
trainmode!(m)

Set a layer, or all layers in a model, to training mode.
Opposite to [`testmode!`](@ref), see further details there.
"""
trainmode!(m, mode = true)
trainmode!(m) = testmode!(m, false)
trainmode!(m, mode::Symbol) = testmode!(m, mode)
trainmode!(m, ::Nothing) = testmode!(m, nothing) # why do we have so much API?

Set a layer of model's train mode (see below).
Symmetric to [`testmode!`](@ref) (i.e. `trainmode!(m, mode) == testmode!(m, !mode)`).
"""
testmode!(model, inactive)

_Note_: if you manually set a model into train mode, you need to manually place
it into test mode during testing phase.
This two-argument method is largely internal. It recurses into the `model`
until a method like `testmode!(d::Dropout, inactive)` alters the activity of a layer.

Possible values include:
- `true` for training
- `false` for testing
- `:auto` or `nothing` for Flux to detect the mode automatically
Possible values of `inactive` are:
- `true` for testing, i.e. `active=false`
- `false` for training, same as [`trainmode!`](@ref)`(m)`
- `:auto` or `nothing` for Flux to detect training automatically.
"""
trainmode!(m, mode = true) = mode isa Bool ? testmode!(m, !mode) : testmode!(m, mode)
function testmode!(m, mode)
if mode isa Symbol && mode !== :auto
throw(ArgumentError("testmode! accepts only the symbol :auto, got :$mode"))
end
foreach(x -> testmode!(x, mode), trainable(m))
m
end

function params!(p::Params, x, seen = IdSet())
if x isa AbstractArray{<:Number} && Functors.isleaf(x)
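As a quick reference for the two-argument `testmode!` defined above, a minimal sketch of the accepted `mode` values and the `active` field they set on a `Dropout` layer (behaviour as defined by the methods in this diff):

```julia
using Flux

d = Dropout(0.3)

Flux.testmode!(d, true)     # test mode:  d.active == false
Flux.testmode!(d, false)    # train mode: d.active == true
Flux.testmode!(d, :auto)    # automatic:  d.active == nothing
Flux.testmode!(d, nothing)  # same as :auto
# Flux.testmode!(d, :banana) # would throw ArgumentError: only :auto is accepted
```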
62 changes: 38 additions & 24 deletions src/layers/normalise.jl
@@ -1,8 +1,13 @@
# Internal function, used only for layers defined in this file.
_isactive(m, x) = isnothing(m.active) ? NNlib.within_gradient(x) : m.active

# Internal function, used only in this file.
_tidy_active(mode::Bool) = mode
_tidy_active(::Nothing) = nothing
_tidy_active(mode) = mode === :auto ? nothing : throw(ArgumentError("active = $(repr(mode)) is not accepted, must be true/false/nothing or :auto"))

"""
Dropout(p; [dims, rng])
Dropout(p; [dims, rng, active])

Layer implementing [dropout](https://arxiv.org/abs/1207.0580) with the given probability.
This is used as a regularisation, i.e. to reduce overfitting.
@@ -12,7 +17,8 @@ or else scales it by `1 / (1 - p)`, using the [`NNlib.dropout`](@ref) function.
While testing, it has no effect.

By default the mode will switch automatically, but it can also
be controlled manually via [`Flux.testmode!`](@ref).
be controlled manually via [`Flux.testmode!`](@ref),
or by passing keyword `active=true` for training mode.

By default every input is treated independently. With the `dims` keyword,
instead it takes a random choice only along that dimension.
@@ -36,7 +42,11 @@ julia> m(ones(2, 7)) # test mode, no effect
2.0 2.0 2.0 2.0 2.0 2.0 2.0
2.0 2.0 2.0 2.0 2.0 2.0 2.0

julia> Flux.trainmode!(m); # equivalent to use within gradient
julia> Flux.trainmode!(m) # equivalent to use within gradient
Chain(
Dense(2 => 3), # 9 parameters
Dropout(0.4, active=true),
)

julia> m(ones(2, 7))
3×7 Matrix{Float64}:
@@ -63,9 +73,9 @@ mutable struct Dropout{F<:Real,D,R<:AbstractRNG}
end
Dropout(p::Real, dims, active) = Dropout(p, dims, active, default_rng_value())

function Dropout(p::Real; dims=:, rng = default_rng_value())
function Dropout(p::Real; dims=:, active::Union{Bool,Nothing} = nothing, rng = default_rng_value())
0 ≤ p ≤ 1 || throw(ArgumentError("Dropout expects 0 ≤ p ≤ 1, got p = $p"))
Dropout(p, dims, nothing, rng)
Dropout(p, dims, active, rng)
end

@functor Dropout
@@ -74,16 +84,17 @@ trainable(a::Dropout) = (;)
(a::Dropout)(x) = dropout(a.rng, x, a.p * _isactive(a, x); dims=a.dims)

testmode!(m::Dropout, mode=true) =
(m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m)
(m.active = isnothing(_tidy_active(mode)) ? nothing : !mode; m)

function Base.show(io::IO, d::Dropout)
print(io, "Dropout(", d.p)
d.dims != (:) && print(io, ", dims = $(repr(d.dims))")
d.dims != (:) && print(io, ", dims=", d.dims)
d.active == nothing || print(io, ", active=", d.active)
print(io, ")")
end

"""
AlphaDropout(p; rng = default_rng_value())
AlphaDropout(p; [rng, active])

A dropout layer. Used in
[Self-Normalizing Neural Networks](https://arxiv.org/abs/1706.02515).
@@ -112,13 +123,13 @@ mutable struct AlphaDropout{F,R<:AbstractRNG}
p::F
active::Union{Bool, Nothing}
rng::R
function AlphaDropout(p, active, rng)
@assert 0 ≤ p ≤ 1
new{typeof(p), typeof(rng)}(p, active, rng)
end
end

AlphaDropout(p, active) = AlphaDropout(p, active, default_rng_value())
AlphaDropout(p; rng = default_rng_value()) = AlphaDropout(p, nothing, rng)
function AlphaDropout(p; rng = default_rng_value(), active::Union{Bool,Nothing} = nothing)
0 ≤ p ≤ 1 || throw(ArgumentError("AlphaDropout expects 0 ≤ p ≤ 1, got p = $p"))
AlphaDropout(p, active, rng)
end

@functor AlphaDropout
trainable(a::AlphaDropout) = (;)
Expand All @@ -138,7 +149,7 @@ function (a::AlphaDropout)(x::AbstractArray{T}) where T
end

testmode!(m::AlphaDropout, mode=true) =
(m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m)
(m.active = isnothing(_tidy_active(mode)) ? nothing : !mode; m)

"""
LayerNorm(size..., λ=identity; affine=true, ϵ=1f-5)
@@ -257,7 +268,7 @@ ChainRulesCore.@non_differentiable _track_stats!(::Any...)
"""
BatchNorm(channels::Integer, λ=identity;
initβ=zeros32, initγ=ones32,
affine = true, track_stats = true,
affine=true, track_stats=true, active=nothing,
ϵ=1f-5, momentum= 0.1f0)

[Batch Normalization](https://arxiv.org/abs/1502.03167) layer.
@@ -310,7 +321,7 @@ end

function BatchNorm(chs::Int, λ=identity;
initβ=zeros32, initγ=ones32,
affine=true, track_stats=true,
affine=true, track_stats=true, active::Union{Bool,Nothing}=nothing,
ϵ=1f-5, momentum=0.1f0)

β = affine ? initβ(chs) : nothing
Expand All @@ -321,7 +332,7 @@ function BatchNorm(chs::Int, λ=identity;
return BatchNorm(λ, β, γ,
μ, σ², ϵ, momentum,
affine, track_stats,
nothing, chs)
active, chs)
end

@functor BatchNorm
@@ -335,12 +346,13 @@ function (BN::BatchNorm)(x::AbstractArray{T,N}) where {T,N}
end

testmode!(m::BatchNorm, mode=true) =
(m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m)
(m.active = isnothing(_tidy_active(mode)) ? nothing : !mode; m)

function Base.show(io::IO, l::BatchNorm)
print(io, "BatchNorm($(l.chs)")
(l.λ == identity) || print(io, ", $(l.λ)")
hasaffine(l) || print(io, ", affine=false")
l.active == nothing || print(io, ", active=", l.active)
print(io, ")")
end

@@ -399,7 +411,7 @@ end

function InstanceNorm(chs::Int, λ=identity;
initβ=zeros32, initγ=ones32,
affine=false, track_stats=false,
affine=false, track_stats=false, active::Union{Bool,Nothing}=nothing,
ϵ=1f-5, momentum=0.1f0)

β = affine ? initβ(chs) : nothing
Expand All @@ -410,7 +422,7 @@ function InstanceNorm(chs::Int, λ=identity;
return InstanceNorm(λ, β, γ,
μ, σ², ϵ, momentum,
affine, track_stats,
nothing, chs)
active, chs)
end

@functor InstanceNorm
@@ -424,12 +436,13 @@ function (l::InstanceNorm)(x::AbstractArray{T,N}) where {T,N}
end

testmode!(m::InstanceNorm, mode=true) =
(m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m)
(m.active = isnothing(_tidy_active(mode)) ? nothing : !mode; m)

function Base.show(io::IO, l::InstanceNorm)
print(io, "InstanceNorm($(l.chs)")
l.λ == identity || print(io, ", $(l.λ)")
hasaffine(l) || print(io, ", affine=false")
l.active == nothing || print(io, ", active=", l.active)
print(io, ")")
end

@@ -495,7 +508,7 @@ trainable(gn::GroupNorm) = hasaffine(gn) ? (β = gn.β, γ = gn.γ) : (;)

function GroupNorm(chs::Int, G::Int, λ=identity;
initβ=zeros32, initγ=ones32,
affine=true, track_stats=false,
affine=true, track_stats=false, active::Union{Bool,Nothing}=nothing,
ϵ=1f-5, momentum=0.1f0)

if track_stats
@@ -514,7 +527,7 @@ end
μ, σ²,
ϵ, momentum,
affine, track_stats,
nothing, chs)
active, chs)
end

function (gn::GroupNorm)(x::AbstractArray)
@@ -529,13 +542,14 @@ function (gn::GroupNorm)(x::AbstractArray)
end

testmode!(m::GroupNorm, mode = true) =
(m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m)
(m.active = isnothing(_tidy_active(mode)) ? nothing : !mode; m)

function Base.show(io::IO, l::GroupNorm)
# print(io, "GroupNorm($(join(size(l.β), ", "))", ", ", l.G)
print(io, "GroupNorm($(l.chs), $(l.G)")
l.λ == identity || print(io, ", ", l.λ)
hasaffine(l) || print(io, ", affine=false")
l.active == nothing || print(io, ", active=", l.active)
print(io, ")")
end

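All of the normalisation-layer changes above thread a new `active` keyword through the constructors. A minimal sketch of how construction looks after this diff (layer sizes are illustrative):

```julia
using Flux

d  = Dropout(0.4; active=true)        # dropout forced on, even outside gradient calls
bn = BatchNorm(3; active=false)       # batch norm frozen to its inference statistics
gn = GroupNorm(4, 2; active=nothing)  # nothing (the default): detect training automatically

d  # prints Dropout(0.4, active=true), matching the updated Base.show methods
```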
18 changes: 9 additions & 9 deletions test/layers/basic.jl
@@ -1,5 +1,5 @@
using Test, Random
import Flux: activations
using Flux: activations

@testset "basic" begin
@testset "helpers" begin
@@ -16,11 +16,11 @@ import Flux: activations
end

@testset "Chain" begin
@test_nowarn Chain(Dense(10, 5, σ), Dense(5, 2))(randn(10))
@test_throws DimensionMismatch Chain(Dense(10, 5, σ),Dense(2, 1))(randn(10))
@test_nowarn Chain(Dense(10, 5, σ), Dense(5, 2))(randn32(10))
@test_throws DimensionMismatch Chain(Dense(10, 5, σ),Dense(2, 1))(randn32(10))
# numeric test should be put into testset of corresponding layer

@test_nowarn Chain(first = Dense(10, 5, σ), second = Dense(5, 2))(randn(10))
@test_nowarn Chain(first = Dense(10, 5, σ), second = Dense(5, 2))(randn32(10))
m = Chain(first = Dense(10, 5, σ), second = Dense(5, 2))
@test m[:first] == m[1]
@test m[1:2] == m
@@ -72,10 +72,10 @@ import Flux: activations
@test_throws MethodError Dense(rand(5), rand(5), tanh)
end
@testset "dimensions" begin
@test length(Dense(10, 5)(randn(10))) == 5
@test_throws DimensionMismatch Dense(10, 5)(randn(1))
@test_throws MethodError Dense(10, 5)(1) # avoid broadcasting
@test_throws MethodError Dense(10, 5).(randn(10)) # avoid broadcasting
@test length(Dense(10 => 5)(randn32(10))) == 5
@test_throws DimensionMismatch Dense(10 => 5)(randn32(1))
@test_throws MethodError Dense(10 => 5)(1) # avoid broadcasting
@test_throws MethodError Dense(10 => 5).(randn32(10)) # avoid broadcasting
@test size(Dense(10, 5)(randn(10))) == (5,)
@test size(Dense(10, 5)(randn(10,2))) == (5,2)
@test size(Dense(10, 5)(randn(10,2,3))) == (5,2,3)
@@ -333,7 +333,7 @@ import Flux: activations
y = m(x)
@test y isa Array{Float32, 3}
@test size(y) == (embed_size, 3, 4)
x3 = onehotbatch(x, 1:1:vocab_size)
x3 = Flux.onehotbatch(x, 1:1:vocab_size)
@test size(x3) == (vocab_size, 3, 4)
y3 = m(x3)
@test size(y3) == (embed_size, 3, 4)
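The test changes above mostly swap `randn` for `randn32` and the old `Dense(in, out)` form for `Dense(in => out)`, so test inputs match the layers' `Float32` parameters. A minimal sketch of the updated idiom (assuming `Flux.randn32`, which these tests rely on):

```julia
using Flux, Test

d = Dense(10 => 5)            # Float32 weights by default
x = Flux.randn32(10)          # Float32 input, avoids a Float64 -> Float32 conversion
@test length(d(x)) == 5
@test eltype(d(x)) == Float32
```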