
Commit 78b09f5

Merge pull request #36 from FluxML/cl/bn
track_stats option for batchnorm
2 parents 6c857ad + b16c745 · commit 78b09f5

File tree: 2 files changed (+107 / -57 lines)
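
The change lets batchnorm and ∇batchnorm be called without affine parameters (g = b = nothing) and without tracked statistics (running_mean = running_var = nothing, or track_stats = false). As a rough usage sketch of the new keyword surface (array sizes, variable names, and the momentum value are illustrative; the calls mirror the updated tests below and assume CUDA and NNlibCUDA are loaded):

using CUDA, NNlibCUDA

x     = CUDA.rand(Float32, 2, 5)  # features × batch
g     = CUDA.rand(Float32, 2)     # scale
b     = CUDA.rand(Float32, 2)     # bias
rmean = CUDA.rand(Float32, 2)     # running mean
rvar  = CUDA.rand(Float32, 2)     # running variance

# Default behaviour (as before): use and update the running statistics.
NNlibCUDA.batchnorm(g, b, x, rmean, rvar, 0.1; training=true, track_stats=true)

# Batch statistics only: running stats are neither used nor updated during training.
NNlibCUDA.batchnorm(g, b, x, rmean, rvar, 0.1; training=true, track_stats=false)

# No affine parameters: a constant scale of 1 and bias of 0 are filled in internally.
NNlibCUDA.batchnorm(nothing, nothing, x, rmean, rvar, 0.1; training=true)

# No running statistics passed at all.
NNlibCUDA.batchnorm(g, b, x, nothing, nothing, 0.1; training=true, track_stats=false)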


ext/NNlibCUDA/src/cudnn/batchnorm.jl

Lines changed: 84 additions & 53 deletions
@@ -13,44 +13,65 @@ end
 
 BNCache() = BNCache(nothing, nothing)
 
-@inline _wsize(y) = (fill(1, ndims(y)-2)..., size(y)[end-1], 1)
+@inline _wsize(x::AbstractArray{<:Any,N}) where N = ntuple(i -> i == N-1 ? size(x, N-1) : 1, N)
+
+function batchnorm(g::Nothing, b::Nothing, x::DenseCuArray,
+                   running_mean, running_var, momentum; kws...)
+  affine_sz = _wsize(x)
+  g = fill!(similar(x, affine_sz), 1)
+  b = fill!(similar(x, affine_sz), 0)
+  return batchnorm(g, b, x, running_mean, running_var, momentum; kws...)
+end
 
 # NOTE: CuDNN supports only 4D and 5D Tensors for BatchNorm Operations
 # so reshape a 2D Tensor into 4D
-batchnorm(g::DenseCuArray{T}, b::DenseCuArray{T}, x::DenseCuArray{T,2},
-          running_mean::DenseCuArray{T}, running_var::DenseCuArray{T}, momentum;
-          cache = nothing, alpha = T(1), beta = T(0),
-          eps = T(1e-5), training = true) where T<:Union{Float32, Float64} =
-  dropdims(batchnorm(g, b, reshape(x, 1, 1, size(x, 1), size(x, 2)), running_mean, running_var, momentum,
-                     cache = cache, alpha = alpha, beta = beta, eps = eps, training = training), dims = (1, 2))
+function batchnorm(g::DenseCuArray{T}, b::DenseCuArray{T}, x::DenseCuArray{T,2},
+                   running_mean, running_var, momentum; kws...) where T<:CUDNNFloat
+  x = reshape(x, 1, 1, size(x, 1), size(x, 2))
+  y = batchnorm(g, b, x, running_mean, running_var, momentum; kws...)
+  return dropdims(y, dims = (1, 2))
+end
 
 function batchnorm(g::DenseCuArray{T}, b::DenseCuArray{T}, x::Union{DenseCuArray{T,4},DenseCuArray{T,5}},
-                   running_mean::DenseCuArray{T}, running_var::DenseCuArray{T}, momentum;
-                   cache = nothing, alpha = T(1), beta = T(0),
-                   eps = T(1e-5), training = true) where T<:Union{Float32, Float64}
-  cudnnBNForward!(similar(x), g, b, x, running_mean, running_var, momentum, cache = cache,
-                  alpha = alpha, beta = beta, eps = eps, training = training)
+                   running_mean, running_var, momentum; kws...) where T<:CUDNNFloat
+  cudnnBNForward!(similar(x), g, b, x, running_mean, running_var, momentum; kws...)
 end
 
 function cudnnBNForward!(y::DenseCuArray{T}, g::DenseCuArray{T}, b::DenseCuArray{T}, x::DenseCuArray{T},
-                         running_mean::DenseCuArray{T}, running_var::DenseCuArray{T},
-                         momentum; cache = nothing,
+                         running_mean, running_var, momentum;
+                         cache = nothing,
                          alpha = T(1), beta = T(0),
-                         eps = T(1e-5), training = true) where T<:Union{Float32, Float64}
+                         eps = T(1e-5),
+                         training = true,
+                         affine = true,
+                         track_stats = true) where T<:CUDNNFloat
   dims = _wsize(x)
   if eps < CUDNN_BN_MIN_EPSILON
-    # warn("eps ",eps," is too small for CuDNN so eps has been assigned the value ", CUDNN_BN_MIN_EPSILON)
+    @warn "eps $eps is too small for CuDNN, setting to CUDNN_BN_MIN_EPSILON=$CUDNN_BN_MIN_EPSILON"
     eps = CUDNN_BN_MIN_EPSILON
   end
+
+  if running_mean === nothing || running_var === nothing
+    running_mean !== running_var && throw(ArgumentError("both or neither of running_mean and running_var must be nothing"))
+    if track_stats || !training
+      running_mean = fill!(similar(x, dims), 0)
+      running_var = fill!(similar(x, dims), 1)
+    end
+  end
+
   xd = cudnnTensorDescriptor(x)
   yd = cudnnTensorDescriptor(y)
   gd = cudnnTensorDescriptor(CUDNN_TENSOR_NCHW, cudnnDataType(T), Cint(length(dims)), dim4(dims,Val(CUDNN_TENSOR_NCHW)))
 
   if training
+    if !track_stats
+      running_mean = CU_NULL
+      running_var = CU_NULL
+    end
 
     if cache !== nothing
-      mean = zeros(CuArray{T}, dims...)
-      ivar = ones(CuArray{T}, dims...)
+      mean = fill!(similar(x, dims), 0)
+      ivar = fill!(similar(x, dims), 1)
     else
       mean = CU_NULL
       ivar = CU_NULL
@@ -68,57 +89,67 @@ function cudnnBNForward!(y::DenseCuArray{T}, g::DenseCuArray{T}, b::DenseCuArray
   return y
 end
 
+function ∇batchnorm(g::Nothing, b::Nothing, x::DenseCuArray, dy::DenseCuArray,
+                    running_mean, running_var, momentum; kws...)
+  affine_sz = _wsize(x)
+  g = fill!(similar(x, affine_sz), 1)
+  b = fill!(similar(x, affine_sz), 0)
+  return ∇batchnorm(g, b, x, dy, running_mean, running_var, momentum; kws...)
+end
+
 function ∇batchnorm(g::DenseCuArray{T}, b::DenseCuArray{T}, x::DenseCuArray{T, 2}, dy::DenseCuArray{T, 2},
-                    running_mean::DenseCuArray{T}, running_var::DenseCuArray{T}, momentum;
-                    cache = nothing, eps = T(1e-5), alpha = T(1),
-                    beta = T(0), training = true) where T<:Union{Float32, Float64}
+                    running_mean, running_var, momentum;
+                    kws...) where T<:CUDNNFloat
   dg, db, dx = ∇batchnorm(g, b, reshape(x, 1, 1, size(x, 1), size(x, 2)), reshape(dy, 1, 1, size(dy, 1),
-                          size(dy, 2)), running_mean, running_var, momentum, cache = cache, eps = eps,
-                          alpha = alpha, beta = beta, training = training)
+                          size(dy, 2)), running_mean, running_var, momentum; kws...)
   (dg, db, dropdims(dx, dims = (1, 2)))
 end
 
+
 function ∇batchnorm(g::DenseCuArray{T}, b::DenseCuArray{T}, x::DenseCuArray{T}, dy::DenseCuArray{T},
-                    running_mean::DenseCuArray{T}, running_var::DenseCuArray{T}, momentum;
-                    cache = nothing, eps = T(1e-5), alpha = T(1),
-                    beta = T(0), training = true) where T<:Union{Float32, Float64}
+                    running_mean, running_var, momentum;
+                    affine=true, kws...) where T<:CUDNNFloat
   dg = similar(g)
   db = similar(b)
   dx = similar(x)
-  cudnnBNBackward!(dg, g, db, dx, x, dy, running_mean, running_var, T(momentum),
-                   training = training, cache = cache, eps = eps, alpha = alpha, beta = beta)
-  (dg, db, dx)
+  cudnnBNBackward!(dg, g, db, dx, x, dy, running_mean, running_var, T(momentum); kws...)
+  if affine
+    (dg, db, dx)
+  else
+    # CUDNN always calculates dg and db, therefore we just have to drop them
+    (nothing, nothing, dx)
+  end
 end
 
 function cudnnBNBackward!(dg::DenseCuArray{T}, g::DenseCuArray{T}, db::DenseCuArray{T},
                           dx::DenseCuArray{T}, x::DenseCuArray{T}, dy::DenseCuArray{T},
-                          running_mean::DenseCuArray{T}, running_var::DenseCuArray{T},
+                          running_mean, running_var,
                           momentum; cache = nothing, eps = T(1e-5),
                           alpha = T(1), beta = T(0),
-                          dalpha = T(1), dbeta = T(0), training = true) where T<:Union{Float32, Float64}
-  if training
-    xd = cudnnTensorDescriptor(x)
-    dyd = cudnnTensorDescriptor(dy)
-    dxd = cudnnTensorDescriptor(dx)
-    gd = cudnnTensorDescriptor(CUDNN_TENSOR_NCHW, cudnnDataType(T), Cint(length(_wsize(x))), dim4(_wsize(x),Val(CUDNN_TENSOR_NCHW)))
-    if cache !== nothing
-      mean, ivar = cache.mean, cache.ivar
-      info("mean and ivar are fetched from the cache")
-    else
-      mean, ivar = CU_NULL, CU_NULL
-    end
-
-    if eps < CUDNN_BN_MIN_EPSILON
-      eps = CUDNN_BN_MIN_EPSILON
-    end
+                          dalpha = T(1), dbeta = T(0), training = true,
+                          track_stats = true) where T<:CUDNNFloat
+  if !track_stats
+    running_mean = CU_NULL
+    running_var = CU_NULL
+  end
 
-    cudnnBatchNormalizationBackward(handle(), CUDNN_BATCHNORM_SPATIAL, scalingParameter(T, alpha), scalingParameter(T, beta), scalingParameter(T, dalpha), scalingParameter(T, dbeta), xd, x, dyd, dy, dxd, dx, gd, g, dg, db, eps, mean, ivar)
+  xd = cudnnTensorDescriptor(x)
+  dyd = cudnnTensorDescriptor(dy)
+  dxd = cudnnTensorDescriptor(dx)
+  gd = cudnnTensorDescriptor(CUDNN_TENSOR_NCHW, cudnnDataType(T), Cint(length(_wsize(x))), dim4(_wsize(x),Val(CUDNN_TENSOR_NCHW)))
+  if cache !== nothing
+    @debug "fetching mean and ivar from the cache"
+    mean, ivar = cache.mean, cache.ivar
   else
-    ivar = 1 ./ sqrt.(reshape(running_var, _wsize(x)) .+ eps)
-    dx .= dy .* reshape(g, _wsize(x)) .* ivar
-    rdims = ((1:ndims(x)-2)..., ndims(x))
-    dg .= vec(sum(dy .* (x .- reshape(running_mean, _wsize(x))) .* ivar, dims=rdims))
-    db .= vec(sum(dy, dims=rdims))
+    mean, ivar = CU_NULL, CU_NULL
+  end
+
+  if eps < CUDNN_BN_MIN_EPSILON
+    @warn "eps $eps is too small for CuDNN, setting to CUDNN_BN_MIN_EPSILON=$CUDNN_BN_MIN_EPSILON"
+    eps = CUDNN_BN_MIN_EPSILON
   end
+
+  cudnnBatchNormalizationBackward(handle(), CUDNN_BATCHNORM_SPATIAL,
+        scalingParameter(T, alpha), scalingParameter(T, beta), scalingParameter(T, dalpha), scalingParameter(T, dbeta),
+        xd, x, dyd, dy, dxd, dx, gd, g, dg, db, eps, mean, ivar)
 end
-

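As an aside, the reworked _wsize helper above builds the per-channel parameter shape handed to the CuDNN tensor descriptors. A small illustration of what it returns, evaluated on plain CPU arrays with the definition copied from the diff (the array shapes are made up for the example):

# Definition from the diff; it accepts any AbstractArray.
_wsize(x::AbstractArray{<:Any,N}) where N = ntuple(i -> i == N-1 ? size(x, N-1) : 1, N)

x4 = rand(Float32, 7, 7, 3, 16)        # W × H × C × N
@assert _wsize(x4) == (1, 1, 3, 1)     # one scale/bias slot per channel

x2 = rand(Float32, 2, 5)               # features × batch, the 2D case
@assert _wsize(reshape(x2, 1, 1, 2, 5)) == (1, 1, 2, 1)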
ext/NNlibCUDA/test/batchnorm.jl

Lines changed: 23 additions & 4 deletions
@@ -1,8 +1,27 @@
 @testset "Batchnorm" begin
     v = CUDA.rand(Float32, 2)
     m = CUDA.rand(Float32, 2, 5)
-    for training in (false, true)
-        NNlibCUDA.batchnorm(v, v, m, v, v, 1.0; training=training)
-        NNlibCUDA.∇batchnorm(v, v, m, m, v, v, 1.0; training=training)
-    end
+
+    @testset for training in (true, false), track_stats in (true, false)
+        kws = (training=training, track_stats=track_stats)
+
+        # Normal
+        NNlibCUDA.batchnorm(v, v, m, v, v, 1.0; kws...)
+        NNlibCUDA.∇batchnorm(v, v, m, m, v, v, 1.0; kws...)
+
+        # No affine
+        NNlibCUDA.batchnorm(nothing, nothing, m, v, v, 1.0; kws...)
+        NNlibCUDA.∇batchnorm(nothing, nothing, m, m, v, v, 1.0; kws...)
+
+        # No tracking
+        NNlibCUDA.batchnorm(v, v, m, nothing, nothing, 1.0; kws...)
+        NNlibCUDA.∇batchnorm(v, v, m, m, nothing, nothing, 1.0; kws...)
+
+        # Both or neither tracked or affine params must be set
+        for (α, β) in ((v, nothing), (nothing, v))
+            @test_throws MethodError NNlibCUDA.batchnorm(α, β, m, v, v, 1.0; kws...)
+            @test_throws MethodError NNlibCUDA.∇batchnorm(α, β, m, m, v, v, 1.0; kws...)
+            @test_throws ArgumentError NNlibCUDA.batchnorm(v, v, m, α, β, 1.0; kws...)
+        end
+    end
 end
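
For completeness, a sketch of consuming the new affine=false path of ∇batchnorm (shapes and the momentum value are illustrative; the (nothing, nothing, dx) return follows from the diff above):

using CUDA, NNlibCUDA

x  = CUDA.rand(Float32, 2, 5)
dy = CUDA.rand(Float32, 2, 5)
g, b = CUDA.rand(Float32, 2), CUDA.rand(Float32, 2)
rmean, rvar = CUDA.rand(Float32, 2), CUDA.rand(Float32, 2)

dg, db, dx = NNlibCUDA.∇batchnorm(g, b, x, dy, rmean, rvar, 0.1; affine=false)
@assert dg === nothing && db === nothing  # CuDNN still computes them; they are dropped
@assert size(dx) == size(x)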
