
Commit 0ed9f7d

Add norm functions
These roughly correspond to Flux's `*Norm` layers.
1 parent 4d73924 commit 0ed9f7d

File tree

5 files changed: +596 −2 lines changed

src/NNlib.jl

Lines changed: 2 additions & 2 deletions
@@ -3,11 +3,10 @@ module NNlib
 using Pkg
 using Requires
 using ChainRulesCore
-import ChainRulesCore: rrule
+import ChainRulesCore: rrule, @ignore_derivatives
 using Base.Broadcast: broadcasted
 using Base.Threads
 using Statistics
-using Statistics: mean
 using LinearAlgebra
 using LinearAlgebra: BlasFloat, Transpose, Adjoint, AdjOrTransAbsMat
 using LinearAlgebra.BLAS: BlasInt, @blasfunc
@@ -85,6 +84,7 @@ include("scatter.jl")
 include("utils.jl")
 include("sampling.jl")
 include("functions.jl")
+include("normalization.jl")

 ## Include implementations
 include("impl/padding_edges.jl")

src/normalization.jl

Lines changed: 305 additions & 0 deletions
@@ -0,0 +1,305 @@

"""
    norm_stats(x, dims)

Calculates sample mean and (uncorrected) variance of `x` along `dims`.

- `dims=(1,...,N-2,N)` for BatchNorm
- `dims=(1,...,N-2)` for InstanceNorm and GroupNorm
- `dims=(1,...,S)` where S < N for LayerNorm

This is more efficient than calling `mean(x; dims)` and `var(x; dims)` separately,
because it can share some computation across both.
Implementors may want to overload this function to use custom kernels and more.
"""
function norm_stats(x, dims)
    μ = mean(x; dims)
    σ² = var(x; dims, mean = μ, corrected = false)
    return μ, σ²
end

function rrule(::typeof(norm_stats), x, dims)
    μ, mean_pullback = rrule(mean, x; dims)
    σ², var_pullback = rrule(var, x; dims, mean = μ, corrected = false)
    function norm_stats_pullback(dargs)
        dμ, dσ² = unthunk(dargs)
        dx = ChainRulesCore.add!!(var_pullback(dσ²)[2], mean_pullback(dμ)[2])
        return (NoTangent(), dx, NoTangent())
    end
    return (μ, σ²), norm_stats_pullback
end
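For a concrete sense of the `dims` conventions above, here is a rough REPL sketch (not part of the commit; it assumes the unexported definitions in this file):

```julia
julia> x = rand(Float32, 3, 3, 2, 4);  # W × H × C × N

julia> μ, σ² = NNlib.norm_stats(x, (1, 2, 4));  # BatchNorm-style reduction dims

julia> size(μ) == size(σ²) == (1, 1, 2, 1)  # one statistic per channel
true
```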
_maybe_reshape(::Nothing, _) = nothing
_maybe_reshape(x, dims) = reshape(x, dims)
_apply_scale_bias(x, ::Nothing, ::Nothing) = x
_apply_scale_bias(x, scale, bias) = x .* scale .+ bias

"""
    norm_helper(x, μ, σ², scale::Union{AbstractArray, Nothing},
                bias::Union{AbstractArray, Nothing}, ϵ::Real, affine_size = size(μ))

Shared code path for all built-in norm functions.

`μ` and `σ²` should be calculated on the fly using [`norm_stats`](@ref),
or extracted from an existing collection such as [`RunningStats`](@ref).
`bias` and `scale` are consistent with cuDNN and Flux.Scale.
We opt for `scale` over `weight` to avoid confusion with dense layers.
If the size of the statistics and affine parameters differ,
use `affine_size` to add padding dimensions as required to match the input.
"""
function norm_helper(x, μ, σ², scale::Union{AbstractArray, Nothing},
                     bias::Union{AbstractArray, Nothing}, ϵ::Real, affine_size = size(μ))
    @ignore_derivatives if isnothing(scale) != isnothing(bias)
        error("both scale and bias must be provided or left as nothing")
    end
    scale′, bias′ = _maybe_reshape(scale, affine_size), _maybe_reshape(bias, affine_size)
    return _apply_scale_bias((x .- μ) ./ sqrt.(σ² .+ ϵ), scale′, bias′)
end
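A minimal sketch of `norm_helper` fed by `norm_stats`; with no affine transform, `scale` and `bias` stay `nothing`:

```julia
julia> x = rand(Float32, 3, 3, 2, 4);

julia> μ, σ² = NNlib.norm_stats(x, (1, 2, 4));

julia> y = NNlib.norm_helper(x, μ, σ², nothing, nothing, 1.0f-5);  # (x .- μ) ./ sqrt.(σ² .+ ϵ)

julia> size(y) == size(x)
true
```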
"""
    RunningStats(mean, variance, momentum)

Contains running mean and variance estimates for stateful norm functions.
`momentum` controls the strength of the moving average update.

If the parameters are mutable, they will be updated in-place.
Otherwise, they will be replaced wholesale.

See also [`update_running_stats!`](@ref).
"""
mutable struct RunningStats{M <: AbstractArray, V <: AbstractArray, MT <: Real}
    mean::M
    variance::V
    momentum::MT
end
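Constructing per-channel running statistics for a 2-channel input might look like the following sketch (the momentum value is illustrative):

```julia
julia> stats = NNlib.RunningStats(zeros(Float32, 2), ones(Float32, 2), 0.1f0);

julia> size(stats.mean), size(stats.variance), stats.momentum
((2,), (2,), 0.1f0)
```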
# Conditionally pulls running stats or calculates them on the fly.
# Part of the reason this is a dedicated function is to have a more type stable pullback.
function maybe_norm_stats(stats::Union{RunningStats, Nothing}, x, dims,
                          use_running_stats::Bool)
    if stats !== nothing && use_running_stats
        # Maintains consistency with mean/var
        sz = Base.setindex(Base.reduced_indices(x, dims) |> Base.to_shape, :, ndims(x) - 1)
        return reshape(stats.mean, sz), reshape(stats.variance, sz)
    end
    # No running stats exist or are disabled in inference mode
    return norm_stats(x, dims)
end

# Kludge so we can close over a Union inner pullback type
struct MaybeNormStatsPullback{B, P <: ProjectTo{AbstractArray}}
    back::B
    projector::P
end
function (pb::MaybeNormStatsPullback)(dargs)
    _, dx = unthunk(pb.back(dargs))
    return (NoTangent(), NoTangent(), pb.projector(dx), NoTangent(), NoTangent())
end
function rrule(::typeof(maybe_norm_stats), stats::Union{RunningStats, Nothing}, x, dims,
               use_running_stats::Bool)
    project = ProjectTo(x)
    noop_back(_) = (NoTangent(), NoTangent())
    if stats === nothing || !use_running_stats
        (μ, σ²), back = rrule(norm_stats, x, dims)
    else
        # The default is to track, so this only happens when a layer is frozen
        sz = Base.setindex(Base.reduced_indices(x, dims) |> Base.to_shape, :, ndims(x) - 1)
        μ, σ², back = reshape(stats.mean, sz), reshape(stats.variance, sz), noop_back
    end
    back_type = Union{typeof(noop_back), _rrule_pullback_rt(norm_stats, x, dims)}
    return (μ, σ²), MaybeNormStatsPullback{back_type, typeof(project)}(back, project)
end
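A sketch of the two branches: with stored statistics enabled, the running values are reshaped to broadcast against `x`; otherwise batch statistics are computed on the fly:

```julia
julia> x = rand(Float32, 3, 3, 2, 4);

julia> stats = NNlib.RunningStats(zeros(Float32, 2), ones(Float32, 2), 0.1f0);

julia> μ, σ² = NNlib.maybe_norm_stats(stats, x, (1, 2, 4), true);   # pull stored statistics

julia> size(μ) == (1, 1, 2, 1) && all(iszero, μ)
true

julia> μ, σ² = NNlib.maybe_norm_stats(stats, x, (1, 2, 4), false);  # falls back to norm_stats
```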
"""
    update_running_stats!(stats::RunningStats, x::AbstractArray{<:Any, N}, μ, σ²,
                          reduce_dims) where {N}

Performs a moving average update for layers with tracked statistics.
`μ` and `σ²` are the sample mean and variance, most likely from [`norm_stats`](@ref).
`reduce_dims` should also match the `dims` argument of [`norm_stats`](@ref).

See also [`RunningStats`](@ref).
"""
function update_running_stats!(stats::RunningStats, x, μ, σ², reduce_dims::Dims)
    V = eltype(σ²)
    momentum = stats.momentum
    res_mtm = one(V) - momentum
    m = prod(size(x, i) for i in reduce_dims)
    correction = m / (m - one(V))

    running_mean, running_var = stats.mean, stats.variance
    if ChainRulesCore.is_inplaceable_destination(running_mean)
        stats.mean .= res_mtm .* running_mean .+ momentum .* vec(μ)
    else
        stats.mean = res_mtm .* running_mean .+ momentum .* vec(μ)
    end
    if ChainRulesCore.is_inplaceable_destination(running_var)
        stats.variance .= res_mtm .* running_var .+ momentum .* correction .* vec(σ²)
    else
        stats.variance = res_mtm .* running_var .+ momentum .* correction .* vec(σ²)
    end
end
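As a sketch of the moving-average update: starting from zeroed running means, the new mean is just `momentum` times the batch mean (the variance additionally picks up the `m / (m - 1)` correction):

```julia
julia> x = rand(Float32, 3, 3, 2, 4);

julia> μ, σ² = NNlib.norm_stats(x, (1, 2, 4));

julia> stats = NNlib.RunningStats(zeros(Float32, 2), ones(Float32, 2), 0.1f0);

julia> NNlib.update_running_stats!(stats, x, μ, σ², (1, 2, 4));

julia> stats.mean ≈ 0.1f0 .* vec(μ)  # (1 - momentum) * old + momentum * batch mean
true
```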
# Convenience functions
# We follow roughly the same arg order as torch.nn.functional.*_norm:
# input, unique args for this particular norm type, bias + scale, eps; kwargs...

"""
    layernorm(x::AbstractArray{<:Any,N}, ::Val{S}, scale = nothing, bias = nothing,
              ϵ = ofeltype(x, 1e-5)) where {N, S}

Functional [Layer Normalization](https://arxiv.org/abs/1607.06450) operation.

Normalizes `x` along the first `S` dimensions.

For an additional learned affine transform, provide an `S`-dimensional `scale` and `bias`.

See also [`batchnorm`](@ref), [`instancenorm`](@ref), and [`groupnorm`](@ref).

# Examples

```jldoctest
julia> using Statistics

julia> xs = rand(3, 3, 3, 2); # a batch of 2 images, each having 3 channels

julia> y = NNlib.layernorm(xs, Val(3));

julia> isapprox(std(y; dims = 1:3), ones(1, 1, 1, 2); atol = 0.1) &&
           std(y; dims = 1:3) != std(xs; dims = 1:3)
true
```
"""
function layernorm(x::AbstractArray{<:Any, N}, ::Val{S}, scale = nothing, bias = nothing,
                   ϵ = ofeltype(x, 1e-5)) where {N, S}
    @ignore_derivatives if S > N
        throw(DimensionMismatch("got $S reduction dims for $N-dimensional array"))
    end
    μ, σ² = norm_stats(x, ntuple(identity, S))
    return norm_helper(x, μ, σ², scale, bias, ϵ, size(x)[1:S])
end
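The docstring's example covers the plain case; a sketch with the learned affine transform (one `scale`/`bias` entry per normalized element) might look like:

```julia
julia> xs = rand(3, 3, 3, 2);

julia> scale, bias = ones(3, 3, 3), zeros(3, 3, 3);  # S-dimensional affine parameters

julia> y = NNlib.layernorm(xs, Val(3), scale, bias);

julia> size(y) == size(xs)
true
```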
"""
    batchnorm(x::AbstractArray{<:Any, N},
              running_stats::Union{RunningStats, Nothing} = nothing,
              scale::Union{AbstractVector, Nothing} = nothing,
              bias::Union{AbstractVector, Nothing} = nothing, ϵ = ofeltype(x, 1e-5);
              training::Bool = within_grad()) where {N}

Functional [Batch Normalization](https://arxiv.org/abs/1502.03167) operation.

Normalizes `x` along each ``D_1×...×D_{N-2}×1×D_N`` input slice,
where `N-1` is the "channel" (or "feature", for 2D inputs) dimension.

Provide a [`RunningStats`](@ref) to fix an estimated mean and variance.
`batchnorm` will renormalize the input using these statistics during inference,
and update them using batch-level statistics when training.
To override this behaviour, manually set a value for `training`.

If specified, `scale` and `bias` will be applied as an additional learned affine transform.

See also [`layernorm`](@ref), [`instancenorm`](@ref), and [`groupnorm`](@ref).
"""
function batchnorm(x::AbstractArray{<:Any, N},
                   running_stats::Union{RunningStats, Nothing} = nothing,
                   scale::Union{AbstractVector, Nothing} = nothing,
                   bias::Union{AbstractVector, Nothing} = nothing, ϵ = ofeltype(x, 1e-5);
                   training::Bool = within_grad()) where {N}
    reduce_dims = ((1:(N - 2))..., N)
    μ, σ² = maybe_norm_stats(running_stats, x, reduce_dims, !training)
    # Because μ and σ² could be updated in-place, we compute the output first
    y = norm_helper(x, μ, σ², scale, bias, ϵ)
    @ignore_derivatives if running_stats !== nothing && training
        update_running_stats!(running_stats, x, μ, σ², reduce_dims)
    end
    return y
end
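A usage sketch with tracked statistics; `training` is passed explicitly here rather than relying on the `within_grad()` default:

```julia
julia> x = rand(Float32, 3, 3, 2, 4);

julia> stats = NNlib.RunningStats(zeros(Float32, 2), ones(Float32, 2), 0.1f0);

julia> y_train = NNlib.batchnorm(x, stats; training = true);   # batch stats, then update `stats`

julia> y_infer = NNlib.batchnorm(x, stats; training = false);  # reuse the stored running stats

julia> size(y_train) == size(y_infer) == size(x)
true
```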
"""
    instancenorm(x::AbstractArray{<:Any, N},
                 running_stats::Union{RunningStats, Nothing} = nothing,
                 scale::Union{AbstractVector, Nothing} = nothing,
                 bias::Union{AbstractVector, Nothing} = nothing, ϵ = ofeltype(x, 1e-5);
                 training::Bool = within_grad()) where {N}

Functional [Instance Normalization](https://arxiv.org/abs/1607.08022) operation.

Normalizes `x` along each ``D_1×...×D_{N-2}×1×1`` input slice.

Provide a [`RunningStats`](@ref) to fix an estimated mean and variance.
`instancenorm` will renormalize the input using these statistics during inference,
and update them using batch-level statistics when training.
To override this behaviour, manually set a value for `training`.

If specified, `scale` and `bias` will be applied as an additional learned affine transform.

See also [`layernorm`](@ref), [`batchnorm`](@ref), and [`groupnorm`](@ref).
"""
function instancenorm(x::AbstractArray{<:Any, N},
                      running_stats::Union{RunningStats, Nothing} = nothing,
                      scale::Union{AbstractVector, Nothing} = nothing,
                      bias::Union{AbstractVector, Nothing} = nothing, ϵ = ofeltype(x, 1e-5);
                      training::Bool = within_grad()) where {N}
    affine_size = (ntuple(_ -> 1, N - 2)..., size(x, N - 1), :)
    reduce_dims = ((1:(N - 2))...,)
    μ, σ² = maybe_norm_stats(running_stats, x, reduce_dims, !training)
    # Because μ and σ² could be updated in-place, we compute the output first
    y = norm_helper(x, μ, σ², scale, bias, ϵ, affine_size)
    ChainRulesCore.@ignore_derivatives if running_stats !== nothing && training
        μ′, σ²′ = mean(μ; dims = N), mean(σ²; dims = N) # Need to sum (C, N) -> (C,)
        update_running_stats!(running_stats, x, μ′, σ²′, reduce_dims)
    end
    return y
end
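A sketch of the stateless case with per-channel affine parameters (no `RunningStats`, `training` set explicitly):

```julia
julia> x = rand(Float32, 3, 3, 2, 4);

julia> scale, bias = ones(Float32, 2), zeros(Float32, 2);  # one entry per channel

julia> y = NNlib.instancenorm(x, nothing, scale, bias; training = true);

julia> size(y) == size(x)
true
```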
"""
    groupnorm(x::AbstractArray{<:Any, N}, groups::Integer,
              scale::Union{AbstractVector, Nothing} = nothing,
              bias::Union{AbstractVector, Nothing} = nothing,
              ϵ = ofeltype(x, 1e-5)) where {N}

Functional [Group Normalization](https://arxiv.org/abs/1803.08494) operation.

Normalizes `x` along the first `N - 2` (spatial) dimensions,
where `N-1` is the "channel" (or "feature", for 2D inputs) dimension,
and the channel dimension is divided into `groups` groups along which statistics are computed.
The number of channels must be an integer multiple of the number of groups.

If specified, `scale` and `bias` will be applied as an additional learned affine transform.

See also [`layernorm`](@ref), [`batchnorm`](@ref), and [`instancenorm`](@ref).

# Examples

```jldoctest
julia> using Statistics

julia> xs = rand(3, 3, 4, 2); # a batch of 2 images, each having 4 channels

julia> y = NNlib.groupnorm(xs, 4);

julia> isapprox(std(y[:, :, 1:2, 1]), 1; atol = 0.1) &&
           std(xs[:, :, 1:2, 1]) != std(y[:, :, 1:2, 1])
true

julia> isapprox(std(y[:, :, 3:4, 2]), 1; atol = 0.1) &&
           std(xs[:, :, 3:4, 2]) != std(y[:, :, 3:4, 2])
true
```
"""
function groupnorm(x::AbstractArray{<:Any, N}, groups::Integer,
                   scale::Union{AbstractVector, Nothing} = nothing,
                   bias::Union{AbstractVector, Nothing} = nothing,
                   ϵ = ofeltype(x, 1e-5)) where {N}
    sz = size(x)
    channels = @ignore_derivatives begin
        ch = sz[max(1, N - 1)]
        newch, remainder = divrem(ch, groups)
        remainder == 0 ? newch :
            throw(ArgumentError("channels $ch should be multiple of groups $groups"))
    end
    affine_size = (ntuple(_ -> 1, N - 2)..., channels, groups, :)
    grouped_size = (sz[1:(N - 2)]..., channels, groups, :)
    x′ = reshape(x, grouped_size)
    μ, σ² = norm_stats(x′, ((1:(N - 2))...,))
    return reshape(norm_helper(x′, μ, σ², scale, bias, ϵ, affine_size), sz)
end

src/utils.jl

Lines changed: 12 additions & 0 deletions
@@ -54,3 +54,15 @@ function reverse_indices(idx::AbstractArray{<:Any,N}) where N
 end

 unsqueeze(x) = reshape(x, 1, size(x)...)
+
+# This is a terrible hack to prevent the spread of type instabilities
+# when the pullback type changes depending on runtime information,
+# e.g. when a normalization layer is "active" vs "inactive".
+function _rrule_pullback_rt(@nospecialize(fn), args...)
+    rt = Base.promote_op(rrule, typeof(fn), map(typeof, args)...)
+    rt <: Tuple{<:Any,<:Any} && return rt.parameters[2]
+    return rt
+end
+
+# Extracted from Flux. Should this have a docstring and/or be in the docs?
+ofeltype(x, y) = convert(float(eltype(x)), y)
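Roughly, `ofeltype` converts a scalar to the floating-point element type of `x`, and `_rrule_pullback_rt` asks inference for the pullback type that `rrule` would return; a quick sketch (assuming the definitions above):

```julia
julia> NNlib.ofeltype(rand(Float32, 3), 1e-5)
1.0f-5

julia> pb_T = NNlib._rrule_pullback_rt(NNlib.norm_stats, rand(3, 3), (1,));  # inferred pullback type
```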
