
Commit c118028

Add norm functions
These roughly correspond to Flux's `*Norm` layers.
1 parent aea063c commit c118028

File tree

5 files changed: +596 −2 lines changed


src/NNlib.jl

Lines changed: 1 addition & 2 deletions
@@ -1,7 +1,7 @@
 module NNlib

 import Atomix
-import ChainRulesCore: rrule
+import ChainRulesCore: rrule, @ignore_derivatives

 using Base.Broadcast: broadcasted
 using Base.Threads
@@ -16,7 +16,6 @@ using Pkg
 using Random
 using Requires
 using Statistics
-using Statistics: mean

 const libblas = Base.libblas_name

src/normalization.jl

Lines changed: 306 additions & 0 deletions
@@ -12,3 +12,309 @@ function ChainRulesCore.rrule(::typeof(batchnorm), g, b, x, running_mean, runnin
    end
    y, batchnorm_pullback
end

"""
    norm_stats(x, dims)

Calculates sample mean and (uncorrected) variance of `x` along `dims`.

- `dims=(1,...,N-2,N)` for BatchNorm
- `dims=(1,...,N-2)` for InstanceNorm and GroupNorm
- `dims=(1,...,S)` where S < N for LayerNorm

This is more efficient than calling `mean(x; dims)` and `var(x; dims)` separately,
because it can share some computation across both.
Implementors may want to overload this function to use custom kernels and more.
"""
function norm_stats(x, dims)
    μ = mean(x; dims)
    σ² = var(x; dims, mean = μ, corrected = false)
    return μ, σ²
end

function rrule(::typeof(norm_stats), x, dims)
    μ, mean_pullback = rrule(mean, x; dims)
    σ², var_pullback = rrule(var, x; dims, mean = μ, corrected = false)
    function norm_stats_pullback(dargs)
        dμ, dσ² = unthunk(dargs)
        dx = ChainRulesCore.add!!(var_pullback(dσ²)[2], mean_pullback(dμ)[2])
        return (NoTangent(), dx, NoTangent())
    end
    return (μ, σ²), norm_stats_pullback
end

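As a quick illustration (not part of the commit), `norm_stats` can be called directly; the array and `dims` below are made up:

```julia
using NNlib

x = reshape(collect(Float32, 1:12), 2, 3, 2)

# One pass over dims (1, 2): per-slice mean and *uncorrected* variance,
# matching mean(x; dims = (1, 2)) and var(x; dims = (1, 2), corrected = false).
μ, σ² = NNlib.norm_stats(x, (1, 2))
size(μ) == size(σ²) == (1, 1, 2)  # true
```
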
_maybe_reshape(::Nothing, _) = nothing
_maybe_reshape(x, dims) = reshape(x, dims)
_apply_scale_bias(x, ::Nothing, ::Nothing) = x
_apply_scale_bias(x, scale, bias) = x .* scale .+ bias

"""
    norm_helper(x, μ, σ², scale::Union{AbstractArray, Nothing},
                bias::Union{AbstractArray, Nothing}, ϵ::Real, affine_size = size(μ))

Shared code path for all built-in norm functions.

`μ` and `σ²` should be calculated on the fly using [`norm_stats`](@ref),
or extracted from an existing collection such as [`RunningStats`](@ref).
`bias` and `scale` are consistent with cuDNN and Flux.Scale.
We opt for `scale` over `weight` to avoid confusion with dense layers.
If the sizes of the statistics and affine parameters differ,
use `affine_size` to add padding dimensions as required to match the input.
"""
function norm_helper(x, μ, σ², scale::Union{AbstractArray, Nothing},
                     bias::Union{AbstractArray, Nothing}, ϵ::Real, affine_size = size(μ))
    @ignore_derivatives if isnothing(scale) != isnothing(bias)
        error("both scale and bias must be provided or left as nothing")
    end
    scale′, bias′ = _maybe_reshape(scale, affine_size), _maybe_reshape(bias, affine_size)
    return _apply_scale_bias((x .- μ) ./ sqrt.(σ² .+ ϵ), scale′, bias′)
end

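For context (not in the diff), here is a hand-rolled per-channel normalization built from `norm_stats` and `norm_helper`; the shapes and the identity scale/bias are illustrative:

```julia
using NNlib

x = randn(Float32, 4, 3, 2)              # (features, channels, batch)
μ, σ² = NNlib.norm_stats(x, (1, 3))      # statistics of shape (1, 3, 1)
scale, bias = ones(Float32, 3), zeros(Float32, 3)

# affine_size pads the length-3 parameters to (1, 3, 1) so they broadcast
# against the normalized input.
y = NNlib.norm_helper(x, μ, σ², scale, bias, 1f-5, (1, 3, 1))
size(y) == size(x)  # true
```
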
"""
    RunningStats(mean, variance, momentum)

Contains running mean and variance estimates for stateful norm functions.
`momentum` controls the strength of the moving average update.

If the parameters are mutable, they will be updated in-place.
Otherwise, they will be replaced wholesale.

See also [`update_running_stats!`](@ref).
"""
mutable struct RunningStats{M <: AbstractArray, V <: AbstractArray, MT <: Real}
    mean::M
    variance::V
    momentum::MT
end

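A minimal sketch (momentum and shapes chosen arbitrarily) of the state a BatchNorm-style layer over 3 channels would carry:

```julia
using NNlib

# One running mean/variance entry per channel, blended with momentum 0.1.
stats = NNlib.RunningStats(zeros(Float32, 3), ones(Float32, 3), 0.1f0)
```
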
# Conditionally pulls running stats or calculates them on the fly.
# Part of the reason this is a dedicated function is to have a more type stable pullback.
function maybe_norm_stats(stats::Union{RunningStats, Nothing}, x, dims,
                          use_running_stats::Bool)
    if stats !== nothing && use_running_stats
        # Maintains consistency with mean/var
        sz = Base.setindex(Base.reduced_indices(x, dims) |> Base.to_shape, :, ndims(x) - 1)
        return reshape(stats.mean, sz), reshape(stats.variance, sz)
    end
    # No running stats exist or are disabled in inference mode
    return norm_stats(x, dims)
end

# Kludge so we can close over a Union inner pullback type
struct MaybeNormStatsPullback{B, P <: ProjectTo{AbstractArray}}
    back::B
    projector::P
end
function (pb::MaybeNormStatsPullback)(dargs)
    _, dx = unthunk(pb.back(dargs))
    return (NoTangent(), NoTangent(), pb.projector(dx), NoTangent(), NoTangent())
end
function rrule(::typeof(maybe_norm_stats), stats::Union{RunningStats, Nothing}, x, dims,
               use_running_stats::Bool)
    project = ProjectTo(x)
    noop_back(_) = (NoTangent(), NoTangent())
    if stats === nothing || !use_running_stats
        (μ, σ²), back = rrule(norm_stats, x, dims)
    else
        # The default is to track, so this only happens when a layer is frozen
        sz = Base.setindex(Base.reduced_indices(x, dims) |> Base.to_shape, :, ndims(x) - 1)
        μ, σ², back = reshape(stats.mean, sz), reshape(stats.variance, sz), noop_back
    end
    back_type = Union{typeof(noop_back), _rrule_pullback_rt(norm_stats, x, dims)}
    return (μ, σ²), MaybeNormStatsPullback{back_type, typeof(project)}(back, project)
end

"""
    update_running_stats!(stats::RunningStats, x::AbstractArray{<:Any, N}, μ, σ²,
                          reduce_dims) where {N}

Performs a moving average update for layers with tracked statistics.
`μ` and `σ²` are the sample mean and variance, most likely from [`norm_stats`](@ref).
`reduce_dims` should also match the `dims` argument of [`norm_stats`](@ref).

See also [`RunningStats`](@ref).
"""
function update_running_stats!(stats::RunningStats, x, μ, σ², reduce_dims::Dims)
    V = eltype(σ²)
    momentum = stats.momentum
    res_mtm = one(V) - momentum
    m = prod(size(x, i) for i in reduce_dims)
    correction = m / (m - one(V))

    running_mean, running_var = stats.mean, stats.variance
    if ChainRulesCore.is_inplaceable_destination(running_mean)
        stats.mean .= res_mtm .* running_mean .+ momentum .* vec(μ)
    else
        stats.mean = res_mtm .* running_mean .+ momentum .* vec(μ)
    end
    if ChainRulesCore.is_inplaceable_destination(running_var)
        stats.variance .= res_mtm .* running_var .+ momentum .* correction .* vec(σ²)
    else
        stats.variance = res_mtm .* running_var .+ momentum .* correction .* vec(σ²)
    end
end

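Continuing the sketch above (not in the commit), a manual training-style update might look like:

```julia
using NNlib

stats = NNlib.RunningStats(zeros(Float32, 3), ones(Float32, 3), 0.1f0)
x = randn(Float32, 5, 3, 8)          # (features, channels, batch)
reduce_dims = (1, 3)                 # same dims a batchnorm would reduce over
μ, σ² = NNlib.norm_stats(x, reduce_dims)

# Moving-average update; the variance term is rescaled by m / (m - 1).
NNlib.update_running_stats!(stats, x, μ, σ², reduce_dims)
```
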
# Convenience functions
# We follow roughly the same arg order as torch.nn.functional.*_norm:
# input, unique args for this particular norm type, bias + scale, eps; kwargs...

"""
    layernorm(x::AbstractArray{<:Any,N}, ::Val{S}, scale = nothing, bias = nothing,
              ϵ = ofeltype(x, 1e-5)) where {N, S}

Functional [Layer Normalization](https://arxiv.org/abs/1607.06450) operation.

Normalizes `x` along the first `S` dimensions.

For an additional learned affine transform, provide an `S`-dimensional `scale` and `bias`.

See also [`batchnorm`](@ref), [`instancenorm`](@ref), and [`groupnorm`](@ref).

# Examples

```jldoctest
julia> using Statistics

julia> xs = rand(3, 3, 3, 2); # a batch of 2 images, each having 3 channels

julia> y = NNlib.layernorm(xs, Val(3));

julia> isapprox(std(y; dims = 1:3), ones(1, 1, 1, 2); atol = 0.1) &&
           std(y; dims = 1:3) != std(xs; dims = 1:3)
true
```
"""
function layernorm(x::AbstractArray{<:Any, N}, ::Val{S}, scale = nothing, bias = nothing,
                   ϵ = ofeltype(x, 1e-5)) where {N, S}
    @ignore_derivatives if S > N
        throw(DimensionMismatch("got $S reduction dims for $N-dimensional array"))
    end
    μ, σ² = norm_stats(x, ntuple(identity, S))
    return norm_helper(x, μ, σ², scale, bias, ϵ, size(x)[1:S])
end

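To illustrate the affine path (not in the diff), `scale` and `bias` matching the first `S` dimensions could be passed like this:

```julia
using NNlib

x = randn(Float32, 6, 5, 4, 2)
scale = ones(Float32, 6, 5, 4)   # same shape as the first S = 3 dims of x
bias  = zeros(Float32, 6, 5, 4)

# Normalize over dims 1:3, then apply the elementwise affine transform.
y = NNlib.layernorm(x, Val(3), scale, bias)
```
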
"""
    batchnorm(x::AbstractArray{<:Any, N},
              running_stats::Union{RunningStats, Nothing} = nothing,
              scale::Union{AbstractVector, Nothing} = nothing,
              bias::Union{AbstractVector, Nothing} = nothing, ϵ = ofeltype(x, 1e-5);
              training::Bool = within_grad()) where {N}

Functional [Batch Normalization](https://arxiv.org/abs/1502.03167) operation.

Normalizes `x` along each ``D_1×...×D_{N-2}×1×D_N`` input slice,
where `N-1` is the "channel" (or "feature", for 2D inputs) dimension.

Provide a [`RunningStats`](@ref) to fix an estimated mean and variance.
`batchnorm` will renormalize the input using these statistics during inference,
and update them using batch-level statistics when training.
To override this behaviour, manually set a value for `training`.

If specified, `scale` and `bias` will be applied as an additional learned affine transform.

See also [`layernorm`](@ref), [`instancenorm`](@ref), and [`groupnorm`](@ref).
"""
function batchnorm(x::AbstractArray{<:Any, N},
                   running_stats::Union{RunningStats, Nothing} = nothing,
                   scale::Union{AbstractVector, Nothing} = nothing,
                   bias::Union{AbstractVector, Nothing} = nothing, ϵ = ofeltype(x, 1e-5);
                   training::Bool = within_grad()) where {N}
    reduce_dims = ((1:(N - 2))..., N)
    μ, σ² = maybe_norm_stats(running_stats, x, reduce_dims, !training)
    # Because μ and σ² could be updated in-place, we compute the output first
    y = norm_helper(x, μ, σ², scale, bias, ϵ)
    @ignore_derivatives if running_stats !== nothing && training
        update_running_stats!(running_stats, x, μ, σ², reduce_dims)
    end
    return y
end

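An illustrative sketch (not part of the commit) of the intended train/inference split when running statistics are tracked:

```julia
using NNlib

x = randn(Float32, 7, 3, 16)         # (features, channels, batch)
stats = NNlib.RunningStats(zeros(Float32, 3), ones(Float32, 3), 0.1f0)
scale, bias = ones(Float32, 3), zeros(Float32, 3)

# Training: normalize with batch statistics and update `stats`.
y_train = NNlib.batchnorm(x, stats, scale, bias; training = true)

# Inference: reuse the running statistics; `stats` is left untouched.
y_test = NNlib.batchnorm(x, stats, scale, bias; training = false)
```
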
"""
    instancenorm(x::AbstractArray{<:Any, N},
                 running_stats::Union{RunningStats, Nothing} = nothing,
                 scale::Union{AbstractVector, Nothing} = nothing,
                 bias::Union{AbstractVector, Nothing} = nothing, ϵ = ofeltype(x, 1e-5);
                 training::Bool = within_grad()) where {N}

Functional [Instance Normalization](https://arxiv.org/abs/1607.08022) operation.

Normalizes `x` along each ``D_1×...×D_{N-2}×1×1`` input slice.

Provide a [`RunningStats`](@ref) to fix an estimated mean and variance.
`instancenorm` will renormalize the input using these statistics during inference,
and update them using batch-level statistics when training.
To override this behaviour, manually set a value for `training`.

If specified, `scale` and `bias` will be applied as an additional learned affine transform.

See also [`layernorm`](@ref), [`batchnorm`](@ref), and [`groupnorm`](@ref).
"""
function instancenorm(x::AbstractArray{<:Any, N},
                      running_stats::Union{RunningStats, Nothing} = nothing,
                      scale::Union{AbstractVector, Nothing} = nothing,
                      bias::Union{AbstractVector, Nothing} = nothing, ϵ = ofeltype(x, 1e-5);
                      training::Bool = within_grad()) where {N}
    affine_size = (ntuple(_ -> 1, N - 2)..., size(x, N - 1), :)
    reduce_dims = ((1:(N - 2))...,)
    μ, σ² = maybe_norm_stats(running_stats, x, reduce_dims, !training)
    # Because μ and σ² could be updated in-place, we compute the output first
    y = norm_helper(x, μ, σ², scale, bias, ϵ, affine_size)
    ChainRulesCore.@ignore_derivatives if running_stats !== nothing && training
        μ′, σ²′ = mean(μ; dims = N), mean(σ²; dims = N) # Need to sum (C, N) -> (C,)
        update_running_stats!(running_stats, x, μ′, σ²′, reduce_dims)
    end
    return y
end

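For comparison with `batchnorm` (illustrative only): statistics here are computed per channel and per sample, so the batch dimension never enters the reduction:

```julia
using NNlib

x = randn(Float32, 4, 4, 3, 2)   # (H, W, channels, batch)

# Each of the 3 × 2 (channel, sample) slices is normalized independently.
y = NNlib.instancenorm(x)
size(y) == size(x)  # true
```
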
"""
    groupnorm(x::AbstractArray{<:Any, N}, groups::Integer,
              scale::Union{AbstractVector, Nothing} = nothing,
              bias::Union{AbstractVector, Nothing} = nothing,
              ϵ = ofeltype(x, 1e-5)) where {N}

Functional [Group Normalization](https://arxiv.org/abs/1803.08494) operation.

Normalizes `x` along the first `N - 2` (spatial) dimensions,
where `N-1` is the "channel" (or "feature", for 2D inputs) dimension.
The channel dimension is divided into `groups` groups, over which statistics are computed.
The number of channels must be an integer multiple of the number of groups.

If specified, `scale` and `bias` will be applied as an additional learned affine transform.

See also [`layernorm`](@ref), [`batchnorm`](@ref), and [`instancenorm`](@ref).

# Examples

```jldoctest
julia> using Statistics

julia> xs = rand(3, 3, 4, 2); # a batch of 2 images, each having 4 channels

julia> y = NNlib.groupnorm(xs, 4);

julia> isapprox(std(y[:, :, 1:2, 1]), 1; atol = 0.1) &&
           std(xs[:, :, 1:2, 1]) != std(y[:, :, 1:2, 1])
true

julia> isapprox(std(y[:, :, 3:4, 2]), 1; atol = 0.1) &&
           std(xs[:, :, 3:4, 2]) != std(y[:, :, 3:4, 2])
true
```
"""
function groupnorm(x::AbstractArray{<:Any, N}, groups::Integer,
                   scale::Union{AbstractVector, Nothing} = nothing,
                   bias::Union{AbstractVector, Nothing} = nothing,
                   ϵ = ofeltype(x, 1e-5)) where {N}
    sz = size(x)
    channels = @ignore_derivatives begin
        ch = sz[max(1, N - 1)]
        newch, remainder = divrem(ch, groups)
        remainder == 0 ? newch :
            throw(ArgumentError("channels $ch should be multiple of groups $groups"))
    end
    affine_size = (ntuple(_ -> 1, N - 2)..., channels, groups, :)
    grouped_size = (sz[1:(N - 2)]..., channels, groups, :)
    x′ = reshape(x, grouped_size)
    μ, σ² = norm_stats(x′, ((1:(N - 2))...,))
    return reshape(norm_helper(x′, μ, σ², scale, bias, ϵ, affine_size), sz)
end

src/utils.jl

Lines changed: 12 additions & 0 deletions
@@ -162,3 +162,15 @@ if VERSION < v"1.7.0-DEV.793"
    end
end

# This is a terrible hack to prevent the spread of type instabilities
# when the pullback type changes depending on runtime information,
# e.g. when a normalization layer is "active" vs "inactive".
function _rrule_pullback_rt(@nospecialize(fn), args...)
    rt = Base.promote_op(rrule, typeof(fn), map(typeof, args)...)
    rt <: Tuple{<:Any,<:Any} && return rt.parameters[2]
    return rt
end

# Extracted from Flux. Should this have a docstring and/or be in the docs?
ofeltype(x, y) = convert(float(eltype(x)), y)
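
For reference (not in the diff), `ofeltype` converts a scalar to the floating-point element type of an array, which is how the norm functions pick their default `ϵ`:

```julia
using NNlib

# 1e-5 is a Float64 literal; convert it to match a Float32 array.
ϵ = NNlib.ofeltype(rand(Float32, 3), 1e-5)   # 1.0f-5 :: Float32
```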
