@@ -226,7 +226,9 @@ function _norm_layer_forward(
226
226
l, x:: AbstractArray{T, N} ; reduce_dims, affine_shape,
227
227
) where {T, N}
228
228
if ! _isactive (l, x) && l. track_stats # testmode with tracked stats
229
- stats_shape = ntuple (i -> i == N- 1 ? size (x, N- 1 ) : 1 , N)
229
+ stats_shape = ChainRulesCore. ignore_derivatives () do
230
+ ntuple (i -> i == N- 1 ? size (x, N- 1 ) : 1 , N)
231
+ end
230
232
μ = reshape (l. μ, stats_shape)
231
233
σ² = reshape (l. σ², stats_shape)
232
234
else # trainmode or testmode without tracked stats
@@ -347,7 +349,9 @@ trainable(bn::BatchNorm) = hasaffine(bn) ? (β = bn.β, γ = bn.γ) : (;)
347
349
function (BN:: BatchNorm )(x:: AbstractArray{T,N} ) where {T,N}
348
350
_size_check (BN, x, N- 1 => BN. chs)
349
351
reduce_dims = [1 : N- 2 ; N]
350
- affine_shape = ntuple (i -> i == N- 1 ? size (x, N- 1 ) : 1 , N)
352
+ affine_shape = ChainRulesCore. ignore_derivatives () do
353
+ ntuple (i -> i == N- 1 ? size (x, N- 1 ) : 1 , N)
354
+ end
351
355
return _norm_layer_forward (BN, x; reduce_dims, affine_shape)
352
356
end
353
357
@@ -439,7 +443,9 @@ trainable(in::InstanceNorm) = hasaffine(in) ? (β = in.β, γ = in.γ) : (;)
439
443
function (l:: InstanceNorm )(x:: AbstractArray{T,N} ) where {T,N}
440
444
_size_check (l, x, N- 1 => l. chs)
441
445
reduce_dims = 1 : N- 2
442
- affine_shape = ntuple (i -> i == N- 1 ? size (x, N- 1 ) : 1 , N)
446
+ affine_shape = ChainRulesCore. ignore_derivatives () do
447
+ ntuple (i -> i == N- 1 ? size (x, N- 1 ) : 1 , N)
448
+ end
443
449
return _norm_layer_forward (l, x; reduce_dims, affine_shape)
444
450
end
445
451
@@ -456,10 +462,10 @@ end
456
462
457
463
"""
458
464
GroupNorm(channels::Int, G::Int, λ = identity;
459
- initβ = zeros32,
465
+ initβ = zeros32,
460
466
initγ = ones32,
461
- affine = true,
462
- eps = 1f-5,
467
+ affine = true,
468
+ eps = 1f-5,
463
469
momentum = 0.1f0)
464
470
465
471
[Group Normalization](https://arxiv.org/abs/1803.08494) layer.
@@ -538,12 +544,14 @@ function GroupNorm(chs::Int, G::Int, λ=identity;
538
544
end
539
545
540
546
function (gn:: GroupNorm )(x:: AbstractArray )
541
- _size_check (gn, x, ndims (x)- 1 => gn. chs)
547
+ _size_check (gn, x, ndims (x)- 1 => gn. chs)
542
548
sz = size (x)
543
549
x2 = reshape (x, sz[1 : end - 2 ]. .. , sz[end - 1 ]÷ gn. G, gn. G, sz[end ])
544
550
N = ndims (x2) # == ndims(x)+1
545
551
reduce_dims = 1 : N- 2
546
- affine_shape = ntuple (i -> i ∈ (N- 1 , N- 2 ) ? size (x2, i) : 1 , N)
552
+ affine_shape = ChainRulesCore. ignore_derivatives () do
553
+ ntuple (i -> i ∈ (N- 1 , N- 2 ) ? size (x2, i) : 1 , N)
554
+ end
547
555
x3 = _norm_layer_forward (gn, x2; reduce_dims, affine_shape)
548
556
return reshape (x3, sz)
549
557
end
0 commit comments