Move dropout to NNlib #2150
Changes from all commits: e96e4ef, fec2d8e, 6c84a6c, eab0b15, 4ab93b3, 0e396a6, f42f475, d7cc49d, 28ac4c4, 9e99422, fc9855b
@@ -123,7 +123,7 @@ LayerNorm
InstanceNorm
GroupNorm
Flux.normalise
Flux.dropout
NNlib.dropout
```

### Test vs. Train
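For readers following the docs change, here is a minimal sketch of calling the function the reference now points to. It assumes the NNlib version targeted by this PR provides `dropout(x, p; dims)` as described below; names like `x` and `ycols` are illustrative only.

```julia
using NNlib

x = ones(Float32, 3, 4)

# Zero each element with probability 0.5; survivors are scaled by 1/(1 - 0.5) = 2.
y = NNlib.dropout(x, 0.5)

# With dims = 2 one random choice is made per column, so whole columns
# are either zeroed or kept (and rescaled).
ycols = NNlib.dropout(x, 0.5; dims = 2)
```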
@@ -1,112 +1,77 @@
# Internal function, used only for layers defined in this file.
_isactive(m, x) = isnothing(m.active) ? NNlib.within_gradient(x) : m.active

_dropout_shape(s, ::Colon) = size(s)
_dropout_shape(s, dims) = tuple((i ∉ dims ? 1 : si for (i, si) ∈ enumerate(size(s)))...)

_dropout_kernel(y::T, p, q) where {T} = y > p ? T(1 / q) : T(0)

"""
    dropout([rng = rng_from_array(x)], x, p; dims=:, active=true)

The dropout function. If `active` is `true`,
for each input, either sets that input to `0` (with probability
`p`) or scales it by `1 / (1 - p)`. `dims` specifies the unbroadcasted dimensions,
e.g. `dims=1` applies dropout along columns and `dims=2` along rows.
If `active` is `false`, it just returns the input `x`.

Specify `rng` for custom RNGs instead of the default RNG.
Note that custom RNGs are only supported on the CPU.

Warning: when using this function, you have to manually manage the activation
state. Usually in fact, dropout is used while training
but is deactivated in the inference phase. This can be
automatically managed using the [`Dropout`](@ref) layer instead of the
`dropout` function.

The [`Dropout`](@ref) layer is what you should use in most scenarios.
"""
function dropout(rng, x, p; dims=:, active::Bool=true)
  active || return x
  y = dropout_mask(rng, x, p, dims=dims)
  return x .* y
end
dropout(x, p; kwargs...) = dropout(rng_from_array(x), x, p; kwargs...)

dropout_mask(rng::CUDA.RNG, x::CuArray, p; kwargs...) = _dropout_mask(rng, x, p; kwargs...)
dropout_mask(rng, x::CuArray, p; kwargs...) =
  throw(ArgumentError("x isa CuArray, but rng isa $(typeof(rng)). dropout_mask only support CUDA.RNG for CuArrays."))
dropout_mask(rng, x, p; kwargs...) = _dropout_mask(rng, x, p; kwargs...)
function _dropout_mask(rng, x, p; dims=:)
  realfptype = float(real(eltype(x)))
  y = rand!(rng, similar(x, realfptype, _dropout_shape(x, dims)))
  y .= _dropout_kernel.(y, p, 1 - p)
  return y
end
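As a reviewer aid, here is a standalone sketch of what the mask construction above computes, written in plain Julia instead of the private helpers; the mask shape and the keep-and-rescale rule mirror `_dropout_shape` and `_dropout_kernel`. The variable names are illustrative.

```julia
x = randn(Float32, 3, 4)
p = 0.3f0
dims = (2,)                    # draw one mask entry per column

# Mask shape: full size along `dims`, singleton elsewhere (cf. _dropout_shape).
mask_size = ntuple(i -> i in dims ? size(x, i) : 1, ndims(x))   # (1, 4)

# Uniform draws, then keep-and-rescale (cf. _dropout_kernel):
# entries above p become 1/(1 - p), the rest become 0.
y = rand(Float32, mask_size)
mask = ifelse.(y .> p, 1 / (1 - p), 0f0)

x .* mask                      # ≈30% of columns zeroed, the rest scaled by 1/0.7
```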
# TODO move this to NNlib
ChainRulesCore.@non_differentiable dropout_mask(::Any, ::Any, ::Any)

"""
    Dropout(p; dims=:, rng = default_rng_value())
    Dropout(p; [dims, rng])

Dropout layer.
Layer implementing [dropout](https://arxiv.org/abs/1207.0580) with the given probability.
This is used as a regularisation, i.e. to reduce overfitting.

While training, for each input, this layer either sets that input to `0` (with probability
`p`) or scales it by `1 / (1 - p)`. To apply dropout along certain dimension(s), specify the
`dims` keyword. e.g. `Dropout(p; dims = 3)` will randomly zero out entire channels on WHCN input
(also called 2D dropout). This is used as a regularisation, i.e. it reduces overfitting during
training.
While training, it sets each input to `0` (with probability `p`)
or else scales it by `1 / (1 - p)`, using the [`NNlib.dropout`](@ref) function.
While testing, it has no effect.

In the forward pass, this layer applies the [`Flux.dropout`](@ref) function. See that for more
details.
By default the mode will switch automatically, but it can also
be controlled manually via [`Flux.testmode!`](@ref).

Specify `rng` to use a custom RNG instead of the default.
Custom RNGs are only supported on the CPU.
By default every input is treated independently. With the `dims` keyword,
instead it takes a random choice only along that dimension.
For example `Dropout(p; dims = 3)` will randomly zero out entire channels on WHCN input
(also called 2D dropout).

Does nothing to the input once [`Flux.testmode!`](@ref) is `true`.
Keyword `rng` lets you specify a custom random number generator.
(Only supported on the CPU.)

# Examples
```jldoctest
julia> m = Chain(Dense(1 => 1), Dropout(1));
```julia
julia> m = Chain(Dense(ones(3,2)), Dropout(0.4))
Chain(
  Dense(2 => 3), # 9 parameters
  Dropout(0.4),
)

julia> Flux.trainmode!(m);
julia> m(ones(2, 7)) # test mode, no effect
3×7 Matrix{Float64}:
 2.0  2.0  2.0  2.0  2.0  2.0  2.0
 2.0  2.0  2.0  2.0  2.0  2.0  2.0
 2.0  2.0  2.0  2.0  2.0  2.0  2.0

julia> y = m([1]);
julia> Flux.trainmode!(m); # would happen within gradient

julia> y == [0]
true
julia> m(ones(2, 7))
3×7 Matrix{Float64}:
 0.0      0.0      3.33333  0.0      0.0      0.0  0.0
 3.33333  0.0      3.33333  0.0      3.33333  0.0  3.33333
 3.33333  3.33333  0.0      3.33333  0.0      0.0  3.33333

julia> m = Chain(Dense(1000 => 1000), Dropout(0.5));
julia> y = m(ones(2, 10_000));

julia> Flux.trainmode!(m);
julia> using Statistics

julia> y = m(ones(1000));
julia> mean(y) # is about 2.0, as for test mode
1.9892222222222182

julia> isapprox(count(==(0), y) / length(y), 0.5, atol=0.1)
true
julia> mean(iszero, y) # is about 0.4
0.40323333333333333
```
"""
mutable struct Dropout{F,D,R<:AbstractRNG}
mutable struct Dropout{F<:Real,D,R<:AbstractRNG}
  p::F
  dims::D
  active::Union{Bool, Nothing}
  rng::R
end
Dropout(p, dims, active) = Dropout(p, dims, active, default_rng_value())
Dropout(p::Real, dims, active) = Dropout(p, dims, active, default_rng_value())
Review comment: Not sure if this is intentional but the error checking seems to only apply to the keyword based constructor.

Reply: I think that's the only "public" one. I have no idea why we have this 3-arg constructor.

Reply: Yeah that's my recollection.

function Dropout(p; dims=:, rng = default_rng_value())
  @assert 0 ≤ p ≤ 1
function Dropout(p::Real; dims=:, rng = default_rng_value())
  0 ≤ p ≤ 1 || throw(ArgumentError("Dropout expects 0 ≤ p ≤ 1, got p = $p"))
  Dropout(p, dims, nothing, rng)
end
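A quick sketch of the new argument check, and of the reviewer's point above that the positional 3-arg constructor bypasses it. This only exercises the constructors shown in this diff.

```julia
using Flux

Dropout(0.4)            # accepted

try
    Dropout(1.5)        # the keyword constructor now throws a descriptive error
catch err
    @show err           # ArgumentError("Dropout expects 0 ≤ p ≤ 1, got p = 1.5")
end

# As noted in the review thread, the positional constructor performs no check,
# so this still constructs a layer with an out-of-range probability:
Dropout(1.5, :, nothing)
```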

@functor Dropout
trainable(a::Dropout) = (;)

function (a::Dropout)(x)
  _isactive(a, x) || return x
  return dropout(a.rng, x, a.p; dims=a.dims, active=true)
end
(a::Dropout)(x) = dropout(a.rng, x, a.p * _isactive(a, x); dims=a.dims)
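The new one-line forward pass folds the activity flag into the probability. A sketch of why this works; the claim that `p == 0` makes the NNlib call a no-op is an assumption about NNlib's behaviour, not something shown in this diff.

```julia
# _isactive returns a Bool, and multiplying a probability by a Bool
# either keeps it or turns it into zero:
0.4 * true    # 0.4  -> training: usual dropout with probability 0.4
0.4 * false   # 0.0  -> testing:  dropout with probability 0, i.e. keep everything

# So in test mode the layer effectively calls (assumption about NNlib):
#   dropout(rng, x, 0.0; dims) ≈ x     # no elements zeroed, no rescaling
```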
|
||
testmode!(m::Dropout, mode=true) = | ||
(m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m) | ||
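For completeness, a sketch of the manual override described in the docstring, with the effect on `active` read off the `testmode!` method above:

```julia
using Flux

d = Dropout(0.5)           # d.active === nothing: mode picked automatically

Flux.testmode!(d)          # d.active == false: dropout disabled
Flux.trainmode!(d)         # d.active == true:  dropout always applied
Flux.testmode!(d, :auto)   # d.active === nothing: back to automatic
```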