Skip to content

Commit a9dc138

Browse files
committed
fix CI on 1.6 and MacOS
1 parent e0e61dd commit a9dc138

File tree

2 files changed

+20
-18
lines changed

2 files changed

+20
-18
lines changed

src/normalization.jl

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,8 @@ _apply_scale_bias(x, scale, bias) = x .* scale .+ bias
5454
5555
Shared code path for all built-in norm functions.
5656
57-
`μ` and `σ²` should be calculated on the fly using [`NNlib.norm_stats`](@ref),
58-
or extracted from an existing collection such as [`NNlib.RunningStats`](@ref).
57+
`μ` and `σ²` should be calculated on the fly using [`norm_stats`](@ref),
58+
or extracted from an existing collection such as [`RunningStats`](@ref).
5959
`bias` and `scale` are consistent with cuDNN and Flux.Scale.
6060
We opt for `scale` over `weight` to avoid confusion with dense layers.
6161
If the size of the statistics and affine parameters differ,
@@ -79,7 +79,7 @@ Contains running mean and variance estimates for stateful norm functions.
7979
If the parameters are mutable, they will be updated in-place.
8080
Otherwise, they will be replaced wholesale.
8181
82-
See also [`NNlib.update_running_stats!`](@ref).
82+
See also [`update_running_stats!`](@ref).
8383
"""
8484
mutable struct RunningStats{M <: AbstractArray, V <: AbstractArray, MT <: Real}
8585
mean::M
@@ -129,10 +129,10 @@ end
129129
reduce_dims) where {N}
130130
131131
Performs a moving average update for layers with tracked statistics.
132-
`μ` and `σ²` are the sample mean and variance, most likely from [`NNlib.norm_stats`](@ref).
133-
`reduce_dims` should also match the `dims` argument of [`NNlib.norm_stats`](@ref).
132+
`μ` and `σ²` are the sample mean and variance, most likely from [`norm_stats`](@ref).
133+
`reduce_dims` should also match the `dims` argument of [`norm_stats`](@ref).
134134
135-
See also [`NNlib.RunningStats`](@ref).
135+
See also [`RunningStats`](@ref).
136136
"""
137137
function update_running_stats!(stats::RunningStats, x, μ, σ², reduce_dims::Dims)
138138
V = eltype(σ²)
@@ -168,7 +168,7 @@ Normalizes `x` along the first `S` dimensions.
168168
169169
For an additional learned affine transform, provide a `S`-dimensional `scale` and `bias`.
170170
171-
See also [`NNlib.batchnorm`](@ref), [`NNlib.instancenorm`](@ref), and [`NNlib.groupnorm`](@ref).
171+
See also [`batchnorm`](@ref), [`instancenorm`](@ref), and [`groupnorm`](@ref).
172172
173173
# Examples
174174
@@ -205,14 +205,14 @@ Functional [Batch Normalization](https://arxiv.org/abs/1502.03167) operation.
205205
Normalizes `x` along each ``D_1×...×D_{N-2}×1×D_N`` input slice,
206206
where `N-1` is the "channel" (or "feature", for 2D inputs) dimension.
207207
208-
Provide a [`NNlib.RunningStats`](@ref) to fix an estimated mean and variance.
208+
Provide a [`RunningStats`](@ref) to fix an estimated mean and variance.
209209
`batchnorm` will renormalize the input using these statistics during inference,
210210
and update them using batch-level statistics when training.
211211
To override this behaviour, manually set a value for `training`.
212212
213213
If specified, `scale` and `bias` will be applied as an additional learned affine transform.
214214
215-
See also [`NNlib.layernorm`](@ref), [`NNlib.instancenorm`](@ref), and [`NNlib.groupnorm`](@ref).
215+
See also [`layernorm`](@ref), [`instancenorm`](@ref), and [`groupnorm`](@ref).
216216
"""
217217
function batchnorm(x::AbstractArray{<:Any, N},
218218
running_stats::Union{RunningStats, Nothing} = nothing,
@@ -247,7 +247,7 @@ To override this behaviour, manually set a value for `training`.
247247
248248
If specified, `scale` and `bias` will be applied as an additional learned affine transform.
249249
250-
See also [`NNlib.layernorm`](@ref), [`NNlib.batchnorm`](@ref), and [`NNlib.groupnorm`](@ref).
250+
See also [`layernorm`](@ref), [`batchnorm`](@ref), and [`groupnorm`](@ref).
251251
"""
252252
function instancenorm(x::AbstractArray{<:Any, N},
253253
running_stats::Union{RunningStats, Nothing} = nothing,
@@ -281,7 +281,7 @@ The number of channels must be an integer multiple of the number of groups.
281281
282282
If specified, `scale` and `bias` will be applied as an additional learned affine transform.
283283
284-
See also [`NNlib.layernorm`](@ref), [`NNlib.batchnorm`](@ref), and [`NNlib.instancenorm`](@ref).
284+
See also [`layernorm`](@ref), [`batchnorm`](@ref), and [`instancenorm`](@ref).
285285
286286
# Examples
287287

test/normalization.jl

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,10 @@ end
3535

3636
# Group/InstanceNorm dimensions
3737
let W = 128, C = 2, N = 2, shape = (W, W, 1, 1)
38-
x = [randn_sample(shape, 1, 1);;; randn_sample(shape, 2, 2);;;;
39-
randn_sample(shape, 3, 3);;; randn_sample(shape, 4, 4)]
38+
# Tile to W x W x 2 x 2
39+
x = cat(cat(randn_sample(shape, 1, 1), randn_sample(shape, 2, 2); dims = 3),
40+
cat(randn_sample(shape, 3, 3), randn_sample(shape, 4, 4); dims = 3);
41+
dims = 4)
4042
μ, σ² = NNlib.norm_stats(x, (1, 2))
4143
@test vec(μ) ≈ 1:(C * N) rtol=0.05
4244
@test vec(σ²) ≈ abs2.(1:(C * N)) rtol=0.05
@@ -60,7 +62,9 @@ end
6062
(running_stats, true, y_ns, y_ns, dx_ns),
6163
(running_stats, false, meanvar, meanvar, NoTangent()),
6264
]
63-
@test NNlib.maybe_norm_stats(stats, x, dims, !training) == y
65+
ŷ = NNlib.maybe_norm_stats(stats, x, dims, !training)
66+
@test ŷ[1] ≈ y[1] rtol=1e-5
67+
@test ŷ[2] ≈ y[2] rtol=1e-5
6468
ŷ, back = rrule(NNlib.maybe_norm_stats, stats, x, dims, !training)
6569
@test ŷ == y_ad
6670
@test back(meanvar) == (NoTangent(), NoTangent(), dx, NoTangent(), NoTangent())
@@ -170,8 +174,7 @@ end
170174
@testset for use_stats in (true, false)
171175
stats = use_stats ? NNlib.RunningStats(zeros(2), ones(2), 0.1) : nothing
172176
y, back = Zygote.pullback(NNlib.instancenorm, x, stats, scale, bias, 1e-5)
173-
@test y ≈ [-1.22474 -1.22474; 0.0 0.0; 1.22474 1.22474;;;
174-
-1.22474 -1.22474; 0.0 0.0; 1.22474 1.22474] rtol=1e-5
177+
@test y ≈ repeat([-1.22474, 0.0, 1.22474], 1, 2, 2) rtol=1e-5
175178

176179
expected_mean, expected_var = [0.5, 0.8], [1.0, 1.0]
177180
if use_stats
@@ -197,8 +200,7 @@ end
197200
end
198201

199202
dx, dstats, dscale, dbias, _ = back(fill!(similar(y), 1))
200-
@test dx ≈ [3.6742 3.6742; 1.22474 1.22474; -1.22474 -1.22474;;;
201-
3.6742 3.6742; 1.22474 1.22474; -1.22474 -1.22474] rtol=1e-5
203+
@test dx ≈ repeat([3.6742, 1.22474, -1.22474], 1, 2, 2) rtol=1e-5
202204
@test dscale == zeros(2)
203205
@test dbias == fill(6.0, 2)
204206
@test dstats === nothing

0 commit comments

Comments
 (0)