42 changes: 29 additions & 13 deletions src/functor.jl
@@ -5,36 +5,52 @@ import Functors: Functors, @functor, functor, fmap, isleaf
using SparseArrays: AbstractSparseArray

"""
testmode!(m, mode = true)
testmode!(m, inactive = true)

Set a layer or model's test mode (see below).
Using `:auto` mode will treat any gradient computation as training.
Set a layer, or all layers in a model, to test mode.
This disables the effect of [`Dropout`](@ref), and similar layers.

_Note_: if you manually set a model into test mode, you need to manually place
it back into train mode during the training phase.

Possible values include:
- `false` for training
Possible values of optional 2nd argument `inactive` are:
- `true` for testing
- `:auto` or `nothing` for Flux to detect the mode automatically
- `false` for training, same as [`trainmode!`](@ref)`(m)`
- `:auto` or `nothing` for Flux to detect training automatically.

# Example

```jldoctest
julia> d = Dropout(0.3)
Dropout(0.3)

julia> testmode!(d) # dropout is now always disabled
Dropout(0.3, active=false)
Member Author:
Are we happy with the name active? This existed as a field name, but not previously exposed.

testmode! and trainmode! both had a positional argument called mode with opposite meanings. I made these active + inactive to match.

Member (@ToucheSir), Apr 1, 2023:
I find the double negative a little confusing but can't come up with a better word. For a future PR, would it make sense to add a third setactive! function and move the true/false/auto/nothing handling logic to that? Then trainmode! and testmode! lose their second arg and become shortcuts for setactive!(model, <true|false>). Either way, we could even use an enum if we're feeling fancy.
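# A rough sketch of that suggestion (hypothetical, not part of this PR): a single
# setactive! owns the true/false/nothing handling and does the recursion, and the
# two mode functions become shortcuts. All names and signatures here are assumptions.
setactive!(m, active::Union{Bool,Nothing}) =
    (foreach(x -> setactive!(x, active), trainable(m)); m)
trainmode!(m) = setactive!(m, true)
testmode!(m) = setactive!(m, false)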

Member Author:
It's awful that they both do everything. It would be OK if either accepted :auto, but never true/false. Maybe that's a deprecation goal?

There could also be a 3rd function, but two is already a lot. Or the 3rd could replace both, but that's more churn.

Member Author:
But more immediately, your enum suggestion could read Dropout(0.5; mode=:test) and Dropout(0.5; mode=:train). That has the advantage of always being one type. It's a little more indirect -- it tells you what the layer is intended for, not what it does.
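# A sketch of how that constructor could read (hypothetical, not what this PR
# implements): the keyword is always a Symbol, mapped onto the existing fields.
function Dropout(p::Real; dims=:, mode::Symbol=:auto, rng=default_rng_value())
    active = mode === :train ? true :
             mode === :test  ? false :
             mode === :auto  ? nothing :
             throw(ArgumentError("mode must be :train, :test or :auto, got $(repr(mode))"))
    Dropout(p, dims, active, rng)
end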

Member (@ToucheSir), Apr 1, 2023:
> It would be OK if either accepted :auto, but never true/false. Maybe that's a deprecation goal?

That's a good idea! Can we keep mode then? In that case, have we considered something like enabled instead of mode or (in)active?

Member Author:
The tricky thing is that trainmode!(m, ::Bool) recurses to itself, and is what's overloaded by layers. Presumably some layers in packages may overload this too.

Deprecating that method and changing the recursion to use something else means that we will break any packages which rely on it.
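# For reference, the pattern in question, simplified from this PR's diff: the
# generic method recurses over trainable children...
testmode!(m, inactive = true) = (foreach(x -> testmode!(x, inactive), trainable(m)); m)
# ...and each stateful layer overloads it to flip its own `active` field, which is
# why changing the recursion target could break packages that overload these methods.
testmode!(m::Dropout, mode = true) =
    (m.active = isnothing(_tidy_active(mode)) ? nothing : !mode; m)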

Member:
Good point. Maybe 3 functions isn't that bad after all then if it makes the deprecation path easier. PyTorch has .train(), .eval() and module.training = {true,false}.


julia> trainmode!(d) # dropout is now always enabled
Dropout(0.3, active=true)

julia> trainmode!(d, :auto) # back to default
Dropout(0.3)
```
"""
testmode!(m, mode = true) = (foreach(x -> testmode!(x, mode), trainable(m)); m)
testmode!(m, inactive = true) = (foreach(x -> testmode!(x, inactive), trainable(m)); m)

"""
trainmode!(m, mode = true)
trainmode!(m, active = true)

Set a layer of model's train mode (see below).
Symmetric to [`testmode!`](@ref) (i.e. `trainmode!(m, mode) == testmode!(m, !mode)`).
Set a layer, or all layers in a model, to training mode.
Opposite to [`testmode!`](@ref) (i.e. `trainmode!(m, active) == testmode!(m, !active)`).

_Note_: if you manually set a model into train mode, you need to manually place
it into test mode during the testing phase.

Possible values include:
Possible values of optional 2nd argument `active` are:
- `true` for training
- `false` for testing
- `:auto` or `nothing` for Flux to detect the mode automatically
- `:auto` or `nothing` for Flux to detect training automatically
"""
trainmode!(m, mode = true) = mode isa Bool ? testmode!(m, !mode) : testmode!(m, mode)
trainmode!(m, active = true) = active isa Bool ? testmode!(m, !active) : testmode!(m, active)

function params!(p::Params, x, seen = IdSet())
if x isa AbstractArray{<:Number} && Functors.isleaf(x)
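For orientation, a minimal usage sketch of the revised pair as documented above (illustrative only; REPL output omitted):

```julia
using Flux

m = Chain(Dense(2 => 3), Dropout(0.3))

testmode!(m)          # Dropout disabled everywhere, regardless of gradient context
trainmode!(m)         # same as testmode!(m, false); Dropout always enabled
trainmode!(m, :auto)  # back to the default: active only inside gradient calls
```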
56 changes: 35 additions & 21 deletions src/layers/normalise.jl
@@ -1,8 +1,13 @@
# Internal function, used only for layers defined in this file.
_isactive(m, x) = isnothing(m.active) ? NNlib.within_gradient(x) : m.active

# Internal function, used only in this file.
_tidy_active(mode::Bool) = mode
_tidy_active(::Nothing) = nothing
_tidy_active(mode) = mode === :auto ? nothing : throw(ArgumentError("active = $(repr(mode)) is not accepted, must be true/false/nothing or :auto"))

"""
Dropout(p; [dims, rng])
Dropout(p; [dims, rng, active])

Layer implementing [dropout](https://arxiv.org/abs/1207.0580) with the given probability.
This is used as a regularisation, i.e. to reduce overfitting.
@@ -12,7 +17,8 @@ or else scales it by `1 / (1 - p)`, using the [`NNlib.dropout`](@ref) function.
While testing, it has no effect.

By default the mode will switch automatically, but it can also
be controlled manually via [`Flux.testmode!`](@ref).
be controlled manually via [`Flux.testmode!`](@ref),
or by passing keyword `active=true` for training mode.

By default every input is treated independently. With the `dims` keyword,
instead it takes a random choice only along that dimension.
@@ -36,7 +42,11 @@ julia> m(ones(2, 7)) # test mode, no effect
2.0 2.0 2.0 2.0 2.0 2.0 2.0
2.0 2.0 2.0 2.0 2.0 2.0 2.0

julia> Flux.trainmode!(m); # equivalent to use within gradient
julia> Flux.trainmode!(m) # equivalent to use within gradient
Chain(
Dense(2 => 3), # 9 parameters
Dropout(0.4, active=true),
)

julia> m(ones(2, 7))
3×7 Matrix{Float64}:
@@ -63,9 +73,9 @@ mutable struct Dropout{F<:Real,D,R<:AbstractRNG}
end
Dropout(p::Real, dims, active) = Dropout(p, dims, active, default_rng_value())

function Dropout(p::Real; dims=:, rng = default_rng_value())
function Dropout(p::Real; dims=:, active = nothing, rng = default_rng_value())
0 ≤ p ≤ 1 || throw(ArgumentError("Dropout expects 0 ≤ p ≤ 1, got p = $p"))
Dropout(p, dims, nothing, rng)
Dropout(p, dims, _tidy_active(active), rng)
end

@functor Dropout
@@ -74,16 +84,17 @@ trainable(a::Dropout) = (;)
(a::Dropout)(x) = dropout(a.rng, x, a.p * _isactive(a, x); dims=a.dims)

testmode!(m::Dropout, mode=true) =
(m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m)
(m.active = isnothing(_tidy_active(mode)) ? nothing : !mode; m)

function Base.show(io::IO, d::Dropout)
print(io, "Dropout(", d.p)
d.dims != (:) && print(io, ", dims = $(repr(d.dims))")
d.dims != (:) && print(io, ", dims=", d.dims)
d.active == nothing || print(io, ", active=", d.active)
print(io, ")")
end

"""
AlphaDropout(p; rng = default_rng_value())
AlphaDropout(p; [rng, active])

A dropout layer. Used in
[Self-Normalizing Neural Networks](https://arxiv.org/abs/1706.02515).
@@ -114,11 +125,11 @@ mutable struct AlphaDropout{F,R<:AbstractRNG}
rng::R
function AlphaDropout(p, active, rng)
@assert 0 ≤ p ≤ 1
new{typeof(p), typeof(rng)}(p, active, rng)
new{typeof(p), typeof(rng)}(p, _tidy_active(active), rng)
end
end
AlphaDropout(p, active) = AlphaDropout(p, active, default_rng_value())
AlphaDropout(p; rng = default_rng_value()) = AlphaDropout(p, nothing, rng)
AlphaDropout(p; rng = default_rng_value(), active=nothing) = AlphaDropout(p, active, rng)

@functor AlphaDropout
trainable(a::AlphaDropout) = (;)
@@ -138,7 +149,7 @@ function (a::AlphaDropout)(x::AbstractArray{T}) where T
end

testmode!(m::AlphaDropout, mode=true) =
(m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m)
(m.active = isnothing(_tidy_active(mode)) ? nothing : !mode; m)

"""
LayerNorm(size..., λ=identity; affine=true, ϵ=1f-5)
@@ -257,7 +268,7 @@ ChainRulesCore.@non_differentiable _track_stats!(::Any...)
"""
BatchNorm(channels::Integer, λ=identity;
initβ=zeros32, initγ=ones32,
affine = true, track_stats = true,
affine=true, track_stats=true, active=:auto,
ϵ=1f-5, momentum=0.1f0)

[Batch Normalization](https://arxiv.org/abs/1502.03167) layer.
@@ -310,7 +321,7 @@ end

function BatchNorm(chs::Int, λ=identity;
initβ=zeros32, initγ=ones32,
affine=true, track_stats=true,
affine=true, track_stats=true, active=nothing,
ϵ=1f-5, momentum=0.1f0)

β = affine ? initβ(chs) : nothing
@@ -321,7 +332,7 @@ function BatchNorm(chs::Int, λ=identity;
return BatchNorm(λ, β, γ,
μ, σ², ϵ, momentum,
affine, track_stats,
nothing, chs)
_tidy_active(active), chs)
end

@functor BatchNorm
Expand All @@ -335,12 +346,13 @@ function (BN::BatchNorm)(x::AbstractArray{T,N}) where {T,N}
end

testmode!(m::BatchNorm, mode=true) =
(m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m)
(m.active = isnothing(_tidy_active(mode)) ? nothing : !mode; m)

function Base.show(io::IO, l::BatchNorm)
print(io, "BatchNorm($(l.chs)")
(l.λ == identity) || print(io, ", $(l.λ)")
hasaffine(l) || print(io, ", affine=false")
l.active == nothing || print(io, ", active=", l.active)
print(io, ")")
end

@@ -399,7 +411,7 @@ end

function InstanceNorm(chs::Int, λ=identity;
initβ=zeros32, initγ=ones32,
affine=false, track_stats=false,
affine=false, track_stats=false, active=nothing,
ϵ=1f-5, momentum=0.1f0)

β = affine ? initβ(chs) : nothing
@@ -410,7 +422,7 @@ function InstanceNorm(chs::Int, λ=identity;
return InstanceNorm(λ, β, γ,
μ, σ², ϵ, momentum,
affine, track_stats,
nothing, chs)
_tidy_active(active), chs)
end

@functor InstanceNorm
@@ -424,12 +436,13 @@ function (l::InstanceNorm)(x::AbstractArray{T,N}) where {T,N}
end

testmode!(m::InstanceNorm, mode=true) =
(m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m)
(m.active = isnothing(_tidy_active(mode)) ? nothing : !mode; m)

function Base.show(io::IO, l::InstanceNorm)
print(io, "InstanceNorm($(l.chs)")
l.λ == identity || print(io, ", $(l.λ)")
hasaffine(l) || print(io, ", affine=false")
l.active == nothing || print(io, ", active=", l.active)
print(io, ")")
end

@@ -495,7 +508,7 @@ trainable(gn::GroupNorm) = hasaffine(gn) ? (β = gn.β, γ = gn.γ) : (;)

function GroupNorm(chs::Int, G::Int, λ=identity;
initβ=zeros32, initγ=ones32,
affine=true, track_stats=false,
affine=true, track_stats=false, active=nothing,
ϵ=1f-5, momentum=0.1f0)

if track_stats
@@ -514,7 +527,7 @@ end
μ, σ²,
ϵ, momentum,
affine, track_stats,
nothing, chs)
_tidy_active(active), chs)
end

function (gn::GroupNorm)(x::AbstractArray)
@@ -529,13 +542,14 @@ function (gn::GroupNorm)(x::AbstractArray)
end

testmode!(m::GroupNorm, mode = true) =
(m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m)
(m.active = isnothing(_tidy_active(mode)) ? nothing : !mode; m)

function Base.show(io::IO, l::GroupNorm)
# print(io, "GroupNorm($(join(size(l.β), ", "))", ", ", l.G)
print(io, "GroupNorm($(l.chs), $(l.G)")
l.λ == identity || print(io, ", ", l.λ)
hasaffine(l) || print(io, ", affine=false")
l.active == nothing || print(io, ", active=", l.active)
print(io, ")")
end

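All of the constructors touched above gain the same `active` keyword; a short sketch of what that looks like at construction time (illustrative only, assuming the updated `show` methods above):

```julia
using Flux

Dropout(0.4; active=true)      # prints as Dropout(0.4, active=true)
BatchNorm(3; active=false)     # fixed in test mode from construction
InstanceNorm(3; active=:auto)  # same as the default; stored internally as `nothing`
GroupNorm(6, 3; active=true)   # 6 channels in 3 groups, fixed in train mode
```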
25 changes: 25 additions & 0 deletions test/layers/normalisation.jl
@@ -26,6 +26,11 @@ evalwgrad(f, x...) = pullback(f, x...)[1]
y = evalwgrad(m, x)
@test count(a->a==0, y) > 50

# Keyword active=false
m2 = Dropout(0.9; active=false, rng_kwargs...)
y2 = evalwgrad(m2, x)
@test count(a->a==0, y2) == 0

x = rand(Float32, 100)
m = Chain(Dense(100,100),
Dropout(0.9; rng_kwargs...))
@@ -73,6 +78,10 @@ evalwgrad(f, x...) = pullback(f, x...)[1]
@test cpu(m).rng === only(values(rng_kwargs))
end
end

@test Dropout(0.5; active=:auto).active === nothing
@test Dropout(0.5; active=true).active === true
@test_throws ArgumentError Dropout(0.5; active=:something_else)
end

@testset "AlphaDropout" begin
@@ -119,6 +128,10 @@ end
@test cpu(m).rng === only(values(rng_kwargs))
end
end

@test AlphaDropout(0.5; active=:auto).active === nothing
@test AlphaDropout(0.5; active=true).active === true
@test_throws ArgumentError AlphaDropout(0.5; active=:something_else)
end

@testset "BatchNorm" begin
@@ -211,6 +224,10 @@ end
@test length(Flux.params(BatchNorm(10))) == 2
@test length(Flux.params(BatchNorm(10, affine=true))) == 2
@test length(Flux.params(BatchNorm(10, affine=false))) == 0

@test BatchNorm(5; active=:auto).active === nothing
@test BatchNorm(5; active=true).active === true
@test_throws ArgumentError BatchNorm(5; active=:something_else)
end

@testset "InstanceNorm" begin
Expand Down Expand Up @@ -342,6 +359,10 @@ end
@test length(Flux.params(InstanceNorm(10))) == 0
@test length(Flux.params(InstanceNorm(10, affine=true))) == 2
@test length(Flux.params(InstanceNorm(10, affine=false))) == 0

@test InstanceNorm(5; active=:auto).active === nothing
@test InstanceNorm(5; active=true).active === true
@test_throws ArgumentError InstanceNorm(5; active=:something_else)
end

@testset "LayerNorm" begin
@@ -465,6 +486,10 @@ end
x = Float32.(reshape(collect(1:prod(sizes)), sizes))
@test BN(x) ≈ GN(x)
end

@test GroupNorm(5, 5; active=:auto).active === nothing
@test GroupNorm(5, 5; active=true).active === true
@test_throws ArgumentError GroupNorm(5, 5; active=:something_else)
end

@testset "second derivatives" begin
Expand Down