@@ -12,3 +12,309 @@ function ChainRulesCore.rrule(::typeof(batchnorm), g, b, x, running_mean, runnin
    end
    y, batchnorm_pullback
end
+
+ """
+     norm_stats(x, dims)
+
+ Calculates sample mean and (uncorrected) variance of `x` along `dims`.
+
+ - `dims=(1,...,N-2,N)` for BatchNorm
+ - `dims=(1,...,N-2)` for InstanceNorm and GroupNorm
+ - `dims=(1,...,S)` where `S < N` for LayerNorm
+
+ This is more efficient than calling `mean(x; dims)` and `var(x; dims)` separately,
+ because it can share some computation across both.
+ Implementors may want to overload this function to use custom kernels and more.
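+
+ For example (a minimal sketch; the array sizes are purely illustrative), a 4-dimensional
+ input uses `dims = (1, 2, 4)` for BatchNorm-style statistics:
+
+ ```jldoctest
+ julia> using Statistics
+
+ julia> x = reshape(collect(Float64, 1:12), 2, 3, 1, 2);
+
+ julia> μ, σ² = NNlib.norm_stats(x, (1, 2, 4));
+
+ julia> only(μ) ≈ mean(x) && only(σ²) ≈ var(x; corrected = false)
+ true
+ ```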
+ """
+ function norm_stats(x, dims)
+     μ = mean(x; dims)
+     σ² = var(x; dims, mean = μ, corrected = false)
+     return μ, σ²
+ end
+
+ function rrule(::typeof(norm_stats), x, dims)
+     μ, mean_pullback = rrule(mean, x; dims)
+     σ², var_pullback = rrule(var, x; dims, mean = μ, corrected = false)
+     function norm_stats_pullback(dargs)
+         dμ, dσ² = unthunk(dargs)
+         # Accumulate the mean and variance contributions to the input cotangent
+         dx = ChainRulesCore.add!!(mean_pullback(dμ)[2], var_pullback(dσ²)[2])
+         return (NoTangent(), dx, NoTangent())
+     end
+     return (μ, σ²), norm_stats_pullback
+ end
+
+ _maybe_reshape(::Nothing, _) = nothing
+ _maybe_reshape(x, dims) = reshape(x, dims)
+ _apply_scale_bias(x, ::Nothing, ::Nothing) = x
+ _apply_scale_bias(x, scale, bias) = x .* scale .+ bias
+
+ """
+     norm_helper(x, μ, σ², scale::Union{AbstractArray, Nothing},
+                 bias::Union{AbstractArray, Nothing}, ϵ::Real, affine_size = size(μ))
+
+ Shared code path for all built-in norm functions.
+
+ `μ` and `σ²` should be calculated on the fly using [`norm_stats`](@ref),
+ or extracted from an existing collection such as [`RunningStats`](@ref).
+ `bias` and `scale` are consistent with cuDNN and `Flux.Scale`.
+ We opt for `scale` over `weight` to avoid confusion with dense layers.
+ If the size of the statistics and affine parameters differ,
+ use `affine_size` to add padding dimensions as required to match the input.
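+
+ As an illustrative sketch (shapes and values are arbitrary), normalizing each column of a
+ matrix with no affine transform:
+
+ ```jldoctest
+ julia> using Statistics
+
+ julia> x = rand(4, 6);
+
+ julia> μ, σ² = NNlib.norm_stats(x, (1,));
+
+ julia> y = NNlib.norm_helper(x, μ, σ², nothing, nothing, 1e-5);
+
+ julia> isapprox(mean(y; dims = 1), zeros(1, 6); atol = 1e-6)
+ true
+ ```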
+ """
+ function norm_helper(x, μ, σ², scale::Union{AbstractArray, Nothing},
+                      bias::Union{AbstractArray, Nothing}, ϵ::Real, affine_size = size(μ))
+     @ignore_derivatives if isnothing(scale) != isnothing(bias)
+         error("both scale and bias must be provided or left as nothing")
+     end
+     scale′, bias′ = _maybe_reshape(scale, affine_size), _maybe_reshape(bias, affine_size)
+     return _apply_scale_bias((x .- μ) ./ sqrt.(σ² .+ ϵ), scale′, bias′)
+ end
+
+ """
+     RunningStats(mean, variance, momentum)
+
+ Contains running mean and variance estimates for stateful norm functions.
+ `momentum` controls the strength of the moving average update.
+
+ If the parameters are mutable, they will be updated in-place.
+ Otherwise, they will be replaced wholesale.
+
+ See also [`update_running_stats!`](@ref).
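+
+ A minimal sketch of construction for a 3-channel feature dimension (the element type,
+ sizes, and momentum below are illustrative, not prescriptive):
+
+ ```jldoctest
+ julia> stats = NNlib.RunningStats(zeros(Float32, 3), ones(Float32, 3), 0.1f0);
+
+ julia> size(stats.mean), size(stats.variance)
+ ((3,), (3,))
+ ```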
+ """
+ mutable struct RunningStats{M <: AbstractArray, V <: AbstractArray, MT <: Real}
+     mean::M
+     variance::V
+     momentum::MT
+ end
+
+ # Conditionally pulls running stats or calculates them on the fly.
+ # Part of the reason this is a dedicated function is to have a more type stable pullback.
+ function maybe_norm_stats(stats::Union{RunningStats, Nothing}, x, dims,
+                           use_running_stats::Bool)
+     if stats !== nothing && use_running_stats
+         # Maintains consistency with mean/var
+         sz = Base.setindex(Base.reduced_indices(x, dims) |> Base.to_shape, :, ndims(x) - 1)
+         return reshape(stats.mean, sz), reshape(stats.variance, sz)
+     end
+     # No running stats exist or are disabled in inference mode
+     return norm_stats(x, dims)
+ end
+
+ # Kludge so we can close over a Union inner pullback type
+ struct MaybeNormStatsPullback{B, P <: ProjectTo{AbstractArray}}
+     back::B
+     projector::P
+ end
+ function (pb::MaybeNormStatsPullback)(dargs)
+     _, dx = unthunk(pb.back(dargs))
+     return (NoTangent(), NoTangent(), pb.projector(dx), NoTangent(), NoTangent())
+ end
+ function rrule(::typeof(maybe_norm_stats), stats::Union{RunningStats, Nothing}, x, dims,
+                use_running_stats::Bool)
+     project = ProjectTo(x)
+     noop_back(_) = (NoTangent(), NoTangent())
+     if stats === nothing || !use_running_stats
+         (μ, σ²), back = rrule(norm_stats, x, dims)
+     else
+         # The default is to track, so this only happens when a layer is frozen
+         sz = Base.setindex(Base.reduced_indices(x, dims) |> Base.to_shape, :, ndims(x) - 1)
+         μ, σ², back = reshape(stats.mean, sz), reshape(stats.variance, sz), noop_back
+     end
+     back_type = Union{typeof(noop_back), _rrule_pullback_rt(norm_stats, x, dims)}
+     return (μ, σ²), MaybeNormStatsPullback{back_type, typeof(project)}(back, project)
+ end
+
+ """
+     update_running_stats!(stats::RunningStats, x::AbstractArray{<:Any, N}, μ, σ²,
+                           reduce_dims) where {N}
+
+ Performs a moving average update for layers with tracked statistics.
+ `μ` and `σ²` are the sample mean and variance, most likely from [`norm_stats`](@ref).
+ `reduce_dims` should also match the `dims` argument of [`norm_stats`](@ref).
+
+ See also [`RunningStats`](@ref).
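+
+ As a sketch of the update rule (all values below are illustrative): with `momentum = 0.1`,
+ each call moves the running mean 10% of the way towards the batch mean, and the batch
+ variance contribution is additionally rescaled by the `m / (m - 1)` bias correction:
+
+ ```jldoctest
+ julia> stats = NNlib.RunningStats(zeros(1), ones(1), 0.1);
+
+ julia> x = reshape(collect(Float64, 1:4), 2, 1, 2);
+
+ julia> μ, σ² = NNlib.norm_stats(x, (1, 3));
+
+ julia> NNlib.update_running_stats!(stats, x, μ, σ², (1, 3));
+
+ julia> stats.mean ≈ [0.9 * 0.0 + 0.1 * 2.5]
+ true
+ ```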
+ """
+ function update_running_stats!(stats::RunningStats, x, μ, σ², reduce_dims::Dims)
+     V = eltype(σ²)
+     momentum = stats.momentum
+     res_mtm = one(V) - momentum
+     m = prod(size(x, i) for i in reduce_dims)
+     correction = m / (m - one(V))
+
+     running_mean, running_var = stats.mean, stats.variance
+     if ChainRulesCore.is_inplaceable_destination(running_mean)
+         stats.mean .= res_mtm .* running_mean .+ momentum .* vec(μ)
+     else
+         stats.mean = res_mtm .* running_mean .+ momentum .* vec(μ)
+     end
+     if ChainRulesCore.is_inplaceable_destination(running_var)
+         stats.variance .= res_mtm .* running_var .+ momentum .* correction .* vec(σ²)
+     else
+         stats.variance = res_mtm .* running_var .+ momentum .* correction .* vec(σ²)
+     end
+ end
+
+ # Convenience functions
+ # We follow roughly the same arg order as torch.nn.functional.*_norm:
+ # input, unique args for this particular norm type, bias + scale, eps; kwargs...
+
+ """
+     layernorm(x::AbstractArray{<:Any, N}, ::Val{S}, scale = nothing, bias = nothing,
+               ϵ = ofeltype(x, 1e-5)) where {N, S}
+
+ Functional [Layer Normalization](https://arxiv.org/abs/1607.06450) operation.
+
+ Normalizes `x` along the first `S` dimensions.
+
+ For an additional learned affine transform, provide an `S`-dimensional `scale` and `bias`.
+
+ See also [`batchnorm`](@ref), [`instancenorm`](@ref), and [`groupnorm`](@ref).
+
+ # Examples
+
+ ```jldoctest
+ julia> using Statistics
+
+ julia> xs = rand(3, 3, 3, 2);  # a batch of 2 images, each having 3 channels
+
+ julia> y = NNlib.layernorm(xs, Val(3));
+
+ julia> isapprox(std(y; dims = 1:3), ones(1, 1, 1, 2); atol = 0.1) &&
+            std(y; dims = 1:3) != std(xs; dims = 1:3)
+ true
+ ```
+ """
+ function layernorm(x::AbstractArray{<:Any, N}, ::Val{S}, scale = nothing, bias = nothing,
+                    ϵ = ofeltype(x, 1e-5)) where {N, S}
+     @ignore_derivatives if S > N
+         throw(DimensionMismatch("got $S reduction dims for $N-dimensional array"))
+     end
+     μ, σ² = norm_stats(x, ntuple(identity, S))
+     return norm_helper(x, μ, σ², scale, bias, ϵ, size(x)[1:S])
+ end
+
+ """
+     batchnorm(x::AbstractArray{<:Any, N},
+               running_stats::Union{RunningStats, Nothing} = nothing,
+               scale::Union{AbstractVector, Nothing} = nothing,
+               bias::Union{AbstractVector, Nothing} = nothing, ϵ = ofeltype(x, 1e-5);
+               training::Bool = within_grad()) where {N}
+
+ Functional [Batch Normalization](https://arxiv.org/abs/1502.03167) operation.
+
+ Normalizes `x` along each ``D_1×...×D_{N-2}×1×D_N`` input slice,
+ where `N-1` is the "channel" (or "feature", for 2D inputs) dimension.
+
+ Provide a [`RunningStats`](@ref) to fix an estimated mean and variance.
+ `batchnorm` will renormalize the input using these statistics during inference,
+ and update them using batch-level statistics when training.
+ To override this behaviour, manually set a value for `training`.
+
+ If specified, `scale` and `bias` will be applied as an additional learned affine transform.
+
+ See also [`layernorm`](@ref), [`instancenorm`](@ref), and [`groupnorm`](@ref).
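+
+ # Examples
+
+ A minimal sketch of stateless usage (no running statistics or affine parameters;
+ the array sizes here are purely illustrative):
+
+ ```jldoctest
+ julia> using Statistics
+
+ julia> xs = rand(3, 3, 3, 2);  # a batch of 2 images, each having 3 channels
+
+ julia> y = NNlib.batchnorm(xs);
+
+ julia> isapprox(std(y; dims = (1, 2, 4)), ones(1, 1, 3, 1); atol = 0.1)
+ true
+ ```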
+ """
+ function batchnorm(x::AbstractArray{<:Any, N},
+                    running_stats::Union{RunningStats, Nothing} = nothing,
+                    scale::Union{AbstractVector, Nothing} = nothing,
+                    bias::Union{AbstractVector, Nothing} = nothing, ϵ = ofeltype(x, 1e-5);
+                    training::Bool = within_grad()) where {N}
+     reduce_dims = ((1:(N - 2))..., N)
+     μ, σ² = maybe_norm_stats(running_stats, x, reduce_dims, !training)
+     # Because μ and σ² could be updated in-place, we compute the output first
+     y = norm_helper(x, μ, σ², scale, bias, ϵ)
+     @ignore_derivatives if running_stats !== nothing && training
+         update_running_stats!(running_stats, x, μ, σ², reduce_dims)
+     end
+     return y
+ end
+
+ """
+     instancenorm(x::AbstractArray{<:Any, N},
+                  running_stats::Union{RunningStats, Nothing} = nothing,
+                  scale::Union{AbstractVector, Nothing} = nothing,
+                  bias::Union{AbstractVector, Nothing} = nothing, ϵ = ofeltype(x, 1e-5);
+                  training::Bool = within_grad()) where {N}
+
+ Functional [Instance Normalization](https://arxiv.org/abs/1607.08022) operation.
+
+ Normalizes `x` along each ``D_1×...×D_{N-2}×1×1`` input slice.
+
+ Provide a [`RunningStats`](@ref) to fix an estimated mean and variance.
+ `instancenorm` will renormalize the input using these statistics during inference,
+ and update them using batch-level statistics when training.
+ To override this behaviour, manually set a value for `training`.
+
+ If specified, `scale` and `bias` will be applied as an additional learned affine transform.
+
+ See also [`layernorm`](@ref), [`batchnorm`](@ref), and [`groupnorm`](@ref).
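+
+ # Examples
+
+ An illustrative sketch, again without running statistics or an affine transform
+ (sizes are arbitrary):
+
+ ```jldoctest
+ julia> using Statistics
+
+ julia> xs = rand(3, 3, 3, 2);  # a batch of 2 images, each having 3 channels
+
+ julia> y = NNlib.instancenorm(xs);
+
+ julia> isapprox(std(y; dims = (1, 2)), ones(1, 1, 3, 2); atol = 0.1)
+ true
+ ```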
+ """
+ function instancenorm(x::AbstractArray{<:Any, N},
+                       running_stats::Union{RunningStats, Nothing} = nothing,
+                       scale::Union{AbstractVector, Nothing} = nothing,
+                       bias::Union{AbstractVector, Nothing} = nothing, ϵ = ofeltype(x, 1e-5);
+                       training::Bool = within_grad()) where {N}
+     affine_size = (ntuple(_ -> 1, N - 2)..., size(x, N - 1), :)
+     reduce_dims = ((1:(N - 2))...,)
+     μ, σ² = maybe_norm_stats(running_stats, x, reduce_dims, !training)
+     # Because μ and σ² could be updated in-place, we compute the output first
+     y = norm_helper(x, μ, σ², scale, bias, ϵ, affine_size)
+     ChainRulesCore.@ignore_derivatives if running_stats !== nothing && training
+         μ′, σ²′ = mean(μ; dims = N), mean(σ²; dims = N)  # Need to reduce (C, N) -> (C,)
+         update_running_stats!(running_stats, x, μ′, σ²′, reduce_dims)
+     end
+     return y
+ end
+
+ """
+     groupnorm(x::AbstractArray{<:Any, N}, groups::Integer,
+               scale::Union{AbstractVector, Nothing} = nothing,
+               bias::Union{AbstractVector, Nothing} = nothing,
+               ϵ = ofeltype(x, 1e-5)) where {N}
+
+ Functional [Group Normalization](https://arxiv.org/abs/1803.08494) operation.
+
+ Normalizes `x` along the first `N - 2` (spatial) dimensions,
+ where `N-1` is the "channel" (or "feature", for 2D inputs) dimension,
+ and the channel dimension is divided into `groups` groups along which statistics are computed.
+ The number of channels must be an integer multiple of the number of groups.
+
+ If specified, `scale` and `bias` will be applied as an additional learned affine transform.
+
+ See also [`layernorm`](@ref), [`batchnorm`](@ref), and [`instancenorm`](@ref).
+
+ # Examples
+
+ ```jldoctest
+ julia> using Statistics
+
+ julia> xs = rand(3, 3, 4, 2);  # a batch of 2 images, each having 4 channels
+
+ julia> y = NNlib.groupnorm(xs, 4);
+
+ julia> isapprox(std(y[:, :, 1:2, 1]), 1; atol = 0.1) &&
+            std(xs[:, :, 1:2, 1]) != std(y[:, :, 1:2, 1])
+ true
+
+ julia> isapprox(std(y[:, :, 3:4, 2]), 1; atol = 0.1) &&
+            std(xs[:, :, 3:4, 2]) != std(y[:, :, 3:4, 2])
+ true
+ ```
+ """
+ function groupnorm(x::AbstractArray{<:Any, N}, groups::Integer,
+                    scale::Union{AbstractVector, Nothing} = nothing,
+                    bias::Union{AbstractVector, Nothing} = nothing,
+                    ϵ = ofeltype(x, 1e-5)) where {N}
+     sz = size(x)
+     channels = @ignore_derivatives begin
+         ch = sz[max(1, N - 1)]
+         newch, remainder = divrem(ch, groups)
+         remainder == 0 ? newch :
+             throw(ArgumentError("channels $ch should be a multiple of groups $groups"))
+     end
+     affine_size = (ntuple(_ -> 1, N - 2)..., channels, groups, :)
+     grouped_size = (sz[1:(N - 2)]..., channels, groups, :)
+     x′ = reshape(x, grouped_size)
+     # Reduce over the spatial dims and the channels-per-group dim of x′,
+     # so statistics are shared within each group.
+     μ, σ² = norm_stats(x′, ((1:(N - 1))...,))
+     return reshape(norm_helper(x′, μ, σ², scale, bias, ϵ, affine_size), sz)
+ end