"""
    norm_stats(x, dims)

Calculates sample mean and (uncorrected) variance of `x` along `dims`.

- `dims=(1,...,N-2,N)` for BatchNorm
- `dims=(1,...,N-2)` for InstanceNorm and GroupNorm
- `dims=(1,...,S)` where S < N for LayerNorm

This is more efficient than calling `mean(x; dims)` and `var(x; dims)` separately,
because it can share some computation across both.
Implementors may want to overload this function to use custom kernels and more.
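
For illustration, the result should agree with calling `mean` and `var` separately
(a rough sketch, not a doctest; the input shape and `dims` are arbitrary):

```julia
x = rand(Float32, 4, 4, 3, 2)
μ, σ² = norm_stats(x, (1, 2, 4))                  # BatchNorm-style reduction
μ ≈ mean(x; dims = (1, 2, 4))                     # true
σ² ≈ var(x; dims = (1, 2, 4), corrected = false)  # true
```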
| 13 | +""" |
| 14 | +function norm_stats(x, dims) |
| 15 | + μ = mean(x; dims) |
| 16 | + σ² = var(x; dims, mean = μ, corrected = false) |
| 17 | + return μ, σ² |
| 18 | +end |
| 19 | + |
| 20 | +function rrule(::typeof(norm_stats), x, dims) |
| 21 | + μ, mean_pullback = rrule(mean, x; dims) |
| 22 | + σ², var_pullback = rrule(var, x; dims, mean = μ, corrected = false) |
| 23 | + function norm_stats_pullback(dargs) |
| 24 | + dμ, dσ² = unthunk(dargs) |
| 25 | + dx = ChainRulesCore.add!!(var_pullback(dμ)[2], mean_pullback(dσ²)[2]) |
| 26 | + return (NoTangent(), dx, NoTangent()) |
| 27 | + end |
| 28 | + return (μ, σ²), norm_stats_pullback |
| 29 | +end |
| 30 | + |
| 31 | +_maybe_reshape(::Nothing, _) = nothing |
| 32 | +_maybe_reshape(x, dims) = reshape(x, dims) |
| 33 | +_apply_scale_bias(x, ::Nothing, ::Nothing) = x |
| 34 | +_apply_scale_bias(x, scale, bias) = x .* scale .+ bias |
| 35 | + |
| 36 | +""" |
| 37 | + norm_helper(x, μ, σ², scale::Union{AbstractArray, Nothing}, |
| 38 | + bias::Union{AbstractArray, Nothing}, ϵ::Real, affine_size = size(μ)) |
| 39 | +
|
| 40 | +Shared code path for all built-in norm functions. |
| 41 | +
|
| 42 | +`μ` and `σ²` should be calculated on the fly using [`norm_stats`](@ref), |
| 43 | +or extracted from an existing collection such as [`RunningStats`](@ref). |
| 44 | +`bias` and `scale` are consistent with cuDNN and Flux.Scale. |
| 45 | +We opt for `scale` over `weight` to avoid confusion with dense layers. |
| 46 | +If the size of the statistics and affine parameters differ, |
| 47 | +use `affine_size` to add padding dimensions as required to match the input. |
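
For example, a BatchNorm-style call might look as follows
(a rough sketch; the shapes and the explicit `affine_size` are illustrative):

```julia
x = rand(Float32, 4, 4, 3, 2)            # (W, H, C, N)
μ, σ² = norm_stats(x, (1, 2, 4))         # statistics of size (1, 1, 3, 1)
scale, bias = rand(Float32, 3), rand(Float32, 3)
# the per-channel parameters are padded to (1, 1, 3, 1) so they broadcast against `x`
y = norm_helper(x, μ, σ², scale, bias, 1f-5, (1, 1, 3, 1))
```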
| 48 | +""" |
| 49 | +function norm_helper(x, μ, σ², scale::Union{AbstractArray, Nothing}, |
| 50 | + bias::Union{AbstractArray, Nothing}, ϵ::Real, affine_size = size(μ)) |
| 51 | + @ignore_derivatives if isnothing(scale) != isnothing(bias) |
| 52 | + error("both scale and bias must be provided or left as nothing") |
| 53 | + end |
| 54 | + scale′, bias′ = _maybe_reshape(scale, affine_size), _maybe_reshape(bias, affine_size) |
| 55 | + return _apply_scale_bias((x .- μ) ./ sqrt.(σ² .+ ϵ), scale′, bias′) |
| 56 | +end |
| 57 | + |
| 58 | +""" |
| 59 | + RunningStats(mean, variance, momentum) |
| 60 | +
|
| 61 | +Contains running mean and variance estimates for stateful norm functions. |
| 62 | +`momentum` controls the strength of the moving average update. |
| 63 | +
|
| 64 | +If the parameters are mutable, they will be updated in-place. |
| 65 | +Otherwise, they will be replaced wholesale. |
| 66 | +
|
| 67 | +See also [`update_running_stats!`](@ref). |
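
# Examples

A BatchNorm-style layer over 3 channels might carry (an illustrative sketch,
not a doctest; the momentum value is arbitrary):

```julia
stats = RunningStats(zeros(Float32, 3), ones(Float32, 3), 0.1f0)
```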
| 68 | +""" |
| 69 | +mutable struct RunningStats{M <: AbstractArray, V <: AbstractArray, MT <: Real} |
| 70 | + mean::M |
| 71 | + variance::V |
| 72 | + momentum::MT |
| 73 | +end |
| 74 | + |
| 75 | +# Conditionally pulls running stats or calculates them on the fly. |
| 76 | +# Part of the reason this is a dedicated function is to have a more type stable pullback. |
| 77 | +function maybe_norm_stats(stats::Union{RunningStats, Nothing}, x, dims, |
| 78 | + use_running_stats::Bool) |
| 79 | + if stats !== nothing && use_running_stats |
| 80 | + # Maintains consistency with mean/var |
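        # e.g. for a (W, H, C, N) input with dims = (1, 2, 4), sz == (1, 1, :, 1),
        # so the stored running-stat vectors broadcast exactly like `mean(x; dims)` would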
        sz = Base.setindex(Base.reduced_indices(x, dims) |> Base.to_shape, :, ndims(x) - 1)
        return reshape(stats.mean, sz), reshape(stats.variance, sz)
    end
    # No running stats exist or are disabled in inference mode
    return norm_stats(x, dims)
end

# Kludge so we can close over a Union inner pullback type
struct MaybeNormStatsPullback{B, P <: ProjectTo{AbstractArray}}
    back::B
    projector::P
end
function (pb::MaybeNormStatsPullback)(dargs)
    _, dx = unthunk(pb.back(dargs))
    return (NoTangent(), NoTangent(), pb.projector(dx), NoTangent(), NoTangent())
end
function rrule(::typeof(maybe_norm_stats), stats::Union{RunningStats, Nothing}, x, dims,
               use_running_stats::Bool)
    project = ProjectTo(x)
    noop_back(_) = (NoTangent(), NoTangent())
    if stats === nothing || !use_running_stats
        (μ, σ²), back = rrule(norm_stats, x, dims)
    else
        # The default is to track, so this only happens when a layer is frozen
        sz = Base.setindex(Base.reduced_indices(x, dims) |> Base.to_shape, :, ndims(x) - 1)
        μ, σ², back = reshape(stats.mean, sz), reshape(stats.variance, sz), noop_back
    end
    back_type = Union{typeof(noop_back), _rrule_pullback_rt(norm_stats, x, dims)}
    return (μ, σ²), MaybeNormStatsPullback{back_type, typeof(project)}(back, project)
end

"""
    update_running_stats!(stats::RunningStats, x::AbstractArray{<:Any, N}, μ, σ²,
                          reduce_dims) where {N}

Performs a moving average update for layers with tracked statistics.
`μ` and `σ²` are the sample mean and variance, most likely from [`norm_stats`](@ref).
`reduce_dims` should also match the `dims` argument of [`norm_stats`](@ref).

See also [`RunningStats`](@ref).
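
Ignoring the in-place vs. out-of-place distinction, the update amounts to
(a sketch of the code below, with `m = prod(size(x, i) for i in reduce_dims)`
and `momentum` taken from `stats`):

```julia
stats.mean     = (1 - momentum) .* stats.mean     .+ momentum .* vec(μ)
stats.variance = (1 - momentum) .* stats.variance .+ momentum .* (m / (m - 1)) .* vec(σ²)  # Bessel-corrected
```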
| 121 | +""" |
| 122 | +function update_running_stats!(stats::RunningStats, x, μ, σ², reduce_dims::Dims) |
| 123 | + V = eltype(σ²) |
| 124 | + momentum = stats.momentum |
| 125 | + res_mtm = one(V) - momentum |
| 126 | + m = prod(size(x, i) for i in reduce_dims) |
| 127 | + correction = m / (m - one(V)) |
| 128 | + |
| 129 | + running_mean, running_var = stats.mean, stats.variance |
| 130 | + if ChainRulesCore.is_inplaceable_destination(running_mean) |
| 131 | + stats.mean .= res_mtm .* running_mean .+ momentum .* vec(μ) |
| 132 | + else |
| 133 | + stats.mean = res_mtm .* running_mean .+ momentum .* vec(μ) |
| 134 | + end |
| 135 | + if ChainRulesCore.is_inplaceable_destination(running_var) |
| 136 | + stats.variance .= res_mtm .* running_var .+ momentum .* correction .* vec(σ²) |
| 137 | + else |
| 138 | + stats.variance = res_mtm .* running_var .+ momentum .* correction .* vec(σ²) |
| 139 | + end |
| 140 | +end |
| 141 | + |
| 142 | +# Convenience functions |
| 143 | +# We follow roughly the same arg order as torch.nn.functional.*_norm: |
| 144 | +# input, unique args for this particular norm type, bias + scale, eps; kwargs... |
| 145 | + |
| 146 | +""" |
| 147 | + layernorm(x::AbstractArray{<:Any,N}, ::Val{S}, scale = nothing, bias = nothing, |
| 148 | + ϵ=ofeltype(x, 1e-5)) where {N, S} |
| 149 | +
|
| 150 | +Functional [Layer Normalization](https://arxiv.org/abs/1607.06450) operation. |
| 151 | +
|
| 152 | +Normalizes `x` along the first `S` dimensions. |
| 153 | +
|
| 154 | +For an additional learned affine transform, provide a `S`-dimensional `scale` and `bias`. |
| 155 | +
|
| 156 | +See also [`batchnorm`](@ref), [`instancenorm`](@ref), and [`groupnorm`](@ref). |
| 157 | +
|
| 158 | +# Examples |
| 159 | +
|
| 160 | +```jldoctest |
| 161 | +julia> using Statistics |
| 162 | +
|
| 163 | +julia> xs = rand(3, 3, 3, 2); # a batch of 2 images, each having 3 channels |
| 164 | +
|
| 165 | +julia> y = NNlib.layernorm(xs, Val(3)); |
| 166 | +
|
| 167 | +julia> isapprox(std(y; dims = 1:3), ones(1, 1, 1, 2); atol = 0.1) && |
| 168 | + std(y; dims = 1:3) != std(xs; dims = 1:3) |
| 169 | +true |
| 170 | +``` |
| 171 | +""" |
| 172 | +function layernorm(x::AbstractArray{<:Any, N}, ::Val{S}, scale = nothing, bias = nothing, |
| 173 | + ϵ = ofeltype(x, 1e-5)) where {N, S} |
| 174 | + @ignore_derivatives if S > N |
| 175 | + throw(DimensionMismatch("got $S reduction dims for $N-dimensional array")) |
| 176 | + end |
| 177 | + μ, σ² = norm_stats(x, ntuple(identity, S)) |
| 178 | + return norm_helper(x, μ, σ², scale, bias, ϵ, size(x)[1:S]) |
| 179 | +end |
| 180 | + |
| 181 | +""" |
| 182 | + batchnorm(x::AbstractArray{<:Any, N}, |
| 183 | + running_stats::Union{RunningStats, Nothing} = nothing, |
| 184 | + scale::Union{AbstractVector, Nothing} = nothing, |
| 185 | + bias::Union{AbstractVector, Nothing} = nothing, ϵ = ofeltype(x, 1e-5); |
| 186 | + training::Bool = within_grad()) where {N} |
| 187 | +
|
| 188 | +Functional [Batch Normalization](https://arxiv.org/abs/1502.03167) operation. |
| 189 | +
|
| 190 | +Normalizes `x` along each ``D_1×...×D_{N-2}×1×D_N`` input slice, |
| 191 | +where `N-1` is the "channel" (or "feature", for 2D inputs) dimension. |
| 192 | +
|
| 193 | +Provide a [`RunningStats`](@ref) to fix a estimated mean and variance. |
| 194 | +`batchnorm` will renormalize the input using these statistics during inference, |
| 195 | +and update them using batch-level statistics when training. |
| 196 | +To override this behaviour, manually set a value for `training`. |
| 197 | +
|
| 198 | +If specified, `scale` and `bias` will be applied as an additional learned affine transform. |
| 199 | +
|
| 200 | +See also [`layernorm`](@ref), [`instancenorm`](@ref), and [`groupnorm`](@ref). |
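
# Examples

A rough usage sketch (not a doctest; shapes, the momentum value,
and the explicit `training` keyword are illustrative):

```julia
xs = rand(Float32, 3, 3, 3, 2)   # a batch of 2 images, each having 3 channels
stats = NNlib.RunningStats(zeros(Float32, 3), ones(Float32, 3), 0.1f0)

y_train = NNlib.batchnorm(xs, stats; training = true)   # uses batch statistics, updates `stats`
y_infer = NNlib.batchnorm(xs, stats; training = false)  # reuses the tracked statistics
```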
| 201 | +""" |
| 202 | +function batchnorm(x::AbstractArray{<:Any, N}, |
| 203 | + running_stats::Union{RunningStats, Nothing} = nothing, |
| 204 | + scale::Union{AbstractVector, Nothing} = nothing, |
| 205 | + bias::Union{AbstractVector, Nothing} = nothing, ϵ = ofeltype(x, 1e-5); |
| 206 | + training::Bool = within_grad()) where {N} |
| 207 | + reduce_dims = ((1:(N - 2))..., N) |
| 208 | + μ, σ² = maybe_norm_stats(running_stats, x, reduce_dims, !training) |
| 209 | + # Because μ and σ² could be updated in-place, we compute the output first |
| 210 | + y = norm_helper(x, μ, σ², scale, bias, ϵ) |
| 211 | + @ignore_derivatives if running_stats !== nothing && training |
| 212 | + update_running_stats!(running_stats, x, μ, σ², reduce_dims) |
| 213 | + end |
| 214 | + return y |
| 215 | +end |
| 216 | + |
| 217 | +""" |
| 218 | + instancenorm(x::AbstractArray{<:Any, N}, |
| 219 | + running_stats::Union{RunningStats, Nothing} = nothing, |
| 220 | + scale::Union{AbstractVector, Nothing} = nothing, |
| 221 | + bias::Union{AbstractVector, Nothing} = nothing, ϵ = ofeltype(x, 1e-5); |
| 222 | + training::Bool = within_grad()) where {N} |
| 223 | +
|
| 224 | +Functional [Instance Normalization](https://arxiv.org/abs/1607.08022) operation. |
| 225 | +
|
| 226 | +Normalizes `x` along each ``D_1×...×D_{N-2}×1×1`` input slice, |
| 227 | +
|
| 228 | +Provide a [`RunningStats`](@ref) to fix a estimated mean and variance. |
| 229 | +`instancenorm` will renormalize the input using these statistics during inference, |
| 230 | +and update them using batch-level statistics when training. |
| 231 | +To override this behaviour, manually set a value for `training`. |
| 232 | +
|
| 233 | +If specified, `scale` and `bias` will be applied as an additional learned affine transform. |
| 234 | +
|
| 235 | +See also [`layernorm`](@ref), [`batchnorm`](@ref), and [`groupnorm`](@ref). |
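
# Examples

A rough usage sketch (not a doctest; shapes and values are illustrative):

```julia
xs = rand(Float32, 3, 3, 3, 2)   # a batch of 2 images, each having 3 channels
y = NNlib.instancenorm(xs)       # normalizes each channel of each image separately

scale, bias = rand(Float32, 3), rand(Float32, 3)
y_affine = NNlib.instancenorm(xs, nothing, scale, bias)
```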
| 236 | +""" |
| 237 | +function instancenorm(x::AbstractArray{<:Any, N}, |
| 238 | + running_stats::Union{RunningStats, Nothing} = nothing, |
| 239 | + scale::Union{AbstractVector, Nothing} = nothing, |
| 240 | + bias::Union{AbstractVector, Nothing} = nothing, ϵ = ofeltype(x, 1e-5); |
| 241 | + training::Bool = within_grad()) where {N} |
| 242 | + affine_size = (ntuple(_ -> 1, N - 2)..., size(x, N - 1), :) |
| 243 | + reduce_dims = ((1:(N - 2))...,) |
| 244 | + μ, σ² = maybe_norm_stats(running_stats, x, reduce_dims, !training) |
| 245 | + # Because μ and σ² could be updated in-place, we compute the output first |
| 246 | + y = norm_helper(x, μ, σ², scale, bias, ϵ, affine_size) |
| 247 | + ChainRulesCore.@ignore_derivatives if running_stats !== nothing && training |
| 248 | + μ′, σ²′ = mean(μ; dims = N), mean(σ²; dims = N) # Need to sum (C, N) -> (C,) |
| 249 | + update_running_stats!(running_stats, x, μ′, σ²′, reduce_dims) |
| 250 | + end |
| 251 | + return y |
| 252 | +end |
| 253 | + |
| 254 | +""" |
| 255 | + groupnorm(x::AbstractArray{<:Any, N}, groups::Integer, |
| 256 | + scale::Union{AbstractVector, Nothing} = nothing, |
| 257 | + bias::Union{AbstractVector, Nothing} = nothing, |
| 258 | + ϵ = ofeltype(x, 1e-5)) where {N} |
| 259 | +
|
| 260 | +Functional [Group Normalization](https://arxiv.org/abs/1803.08494) operation. |
| 261 | +
|
| 262 | +Normalizes `x` along the first `N - 2` (spatial) dimensions, |
| 263 | +where `N-1` is the "channel" (or "feature", for 2D inputs) dimension, |
| 264 | +and the channel dimension is divided into `groups` groups along which statistics are computed. |
| 265 | +The number of channels must be an integer multiple of the number of groups. |
| 266 | +
|
| 267 | +If specified, `scale` and `bias` will be applied as an additional learned affine transform. |
| 268 | +
|
| 269 | +See also [`layernorm`](@ref), [`batchnorm`](@ref), and [`instancenorm`](@ref). |
| 270 | +
|
| 271 | +# Examples |
| 272 | +
|
| 273 | +```jldoctest |
| 274 | +julia> using Statistics |
| 275 | +
|
| 276 | +julia> xs = rand(3, 3, 4, 2); # a batch of 2 images, each having 4 channels |
| 277 | +
|
| 278 | +julia> y = NNlib.groupnorm(xs, 4); |
| 279 | +
|
| 280 | +julia> isapprox(std(y[:, :, 1:2, 1]), 1; atol = 0.1) && |
| 281 | + std(xs[:, :, 1:2, 1]) != std(y[:, :, 1:2, 1]) |
| 282 | +true |
| 283 | +
|
| 284 | +julia> isapprox(std(y[:, :, 3:4, 2]), 1; atol = 0.1) && |
| 285 | + std(xs[:, :, 3:4, 2]) != std(y[:, :, 3:4, 2]) |
| 286 | +true |
| 287 | +``` |
| 288 | +""" |
| 289 | +function groupnorm(x::AbstractArray{<:Any, N}, groups::Integer, |
| 290 | + scale::Union{AbstractVector, Nothing} = nothing, |
| 291 | + bias::Union{AbstractVector, Nothing} = nothing, |
| 292 | + ϵ = ofeltype(x, 1e-5)) where {N} |
| 293 | + sz = size(x) |
| 294 | + channels = @ignore_derivatives begin |
| 295 | + ch = sz[max(1, N - 1)] |
| 296 | + newch, remainder = divrem(ch, groups) |
| 297 | + remainder == 0 ? newch : |
| 298 | + throw(ArgumentError("channels $ch should be multiple of groups $groups")) |
| 299 | + end |
| 300 | + affine_size = (ntuple(_ -> 1, N - 2)..., channels, groups, :) |
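    # Split the channel dim into (channels ÷ groups, groups) so statistics can be taken
    # per group, e.g. (W, H, C, N) -> (W, H, C ÷ groups, groups, N)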
    grouped_size = (sz[1:(N - 2)]..., channels, groups, :)
    x′ = reshape(x, grouped_size)
    μ, σ² = norm_stats(x′, ((1:(N - 1))...,))
    return reshape(norm_helper(x′, μ, σ², scale, bias, ϵ, affine_size), sz)
end