Skip to content

Commit 8718491

Browse files
authored
Merge pull request #2363 from FluxML/bc/norm-ad-ignores
Non-diff shape handling in norm layers
2 parents e9fb65c + 790eb84 commit 8718491

File tree

1 file changed

+16
-8
lines changed

1 file changed

+16
-8
lines changed

src/layers/normalise.jl

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,9 @@ function _norm_layer_forward(
226226
l, x::AbstractArray{T, N}; reduce_dims, affine_shape,
227227
) where {T, N}
228228
if !_isactive(l, x) && l.track_stats # testmode with tracked stats
229-
stats_shape = ntuple(i -> i == N-1 ? size(x, N-1) : 1, N)
229+
stats_shape = ChainRulesCore.ignore_derivatives() do
230+
ntuple(i -> i == N-1 ? size(x, N-1) : 1, N)
231+
end
230232
μ = reshape(l.μ, stats_shape)
231233
σ² = reshape(l.σ², stats_shape)
232234
else # trainmode or testmode without tracked stats
@@ -347,7 +349,9 @@ trainable(bn::BatchNorm) = hasaffine(bn) ? (β = bn.β, γ = bn.γ) : (;)
347349
function (BN::BatchNorm)(x::AbstractArray{T,N}) where {T,N}
348350
_size_check(BN, x, N-1 => BN.chs)
349351
reduce_dims = [1:N-2; N]
350-
affine_shape = ntuple(i -> i == N-1 ? size(x, N-1) : 1, N)
352+
affine_shape = ChainRulesCore.ignore_derivatives() do
353+
ntuple(i -> i == N-1 ? size(x, N-1) : 1, N)
354+
end
351355
return _norm_layer_forward(BN, x; reduce_dims, affine_shape)
352356
end
353357

@@ -439,7 +443,9 @@ trainable(in::InstanceNorm) = hasaffine(in) ? (β = in.β, γ = in.γ) : (;)
439443
function (l::InstanceNorm)(x::AbstractArray{T,N}) where {T,N}
440444
_size_check(l, x, N-1 => l.chs)
441445
reduce_dims = 1:N-2
442-
affine_shape = ntuple(i -> i == N-1 ? size(x, N-1) : 1, N)
446+
affine_shape = ChainRulesCore.ignore_derivatives() do
447+
ntuple(i -> i == N-1 ? size(x, N-1) : 1, N)
448+
end
443449
return _norm_layer_forward(l, x; reduce_dims, affine_shape)
444450
end
445451

@@ -456,10 +462,10 @@ end
456462

457463
"""
458464
GroupNorm(channels::Int, G::Int, λ = identity;
459-
initβ = zeros32,
465+
initβ = zeros32,
460466
initγ = ones32,
461-
affine = true,
462-
eps = 1f-5,
467+
affine = true,
468+
eps = 1f-5,
463469
momentum = 0.1f0)
464470
465471
[Group Normalization](https://arxiv.org/abs/1803.08494) layer.
@@ -538,12 +544,14 @@ function GroupNorm(chs::Int, G::Int, λ=identity;
538544
end
539545

540546
function (gn::GroupNorm)(x::AbstractArray)
541-
_size_check(gn, x, ndims(x)-1 => gn.chs)
547+
_size_check(gn, x, ndims(x)-1 => gn.chs)
542548
sz = size(x)
543549
x2 = reshape(x, sz[1:end-2]..., sz[end-1]÷gn.G, gn.G, sz[end])
544550
N = ndims(x2) # == ndims(x)+1
545551
reduce_dims = 1:N-2
546-
affine_shape = ntuple(i -> i in (N-1, N-2) ? size(x2, i) : 1, N)
552+
affine_shape = ChainRulesCore.ignore_derivatives() do
553+
ntuple(i -> i in (N-1, N-2) ? size(x2, i) : 1, N)
554+
end
547555
x3 = _norm_layer_forward(gn, x2; reduce_dims, affine_shape)
548556
return reshape(x3, sz)
549557
end

0 commit comments

Comments
 (0)