
Commit b3e9682

cleanup and tests passing
1 parent 9c32265 commit b3e9682

2 files changed (+58, -52)

ext/NNlibCUDA/src/cudnn/batchnorm.jl

Lines changed: 35 additions & 41 deletions
@@ -13,61 +13,65 @@ end
 
 BNCache() = BNCache(nothing, nothing)
 
-@inline _wsize(y) = ntuple(i -> i == ndims(y) - 1 ? 1 : size(y, i), ndims(y))
+@inline _wsize(x::AbstractArray{<:Any,N}) where N = ntuple(i -> i == N-1 ? size(x, N-1) : 1, N)
 
 function batchnorm(g::Nothing, b::Nothing, x::DenseCuArray,
-                   running_mean, running_var, momentum;
-                   kws...)
+                   running_mean, running_var, momentum; kws...)
   affine_sz = _wsize(x)
   g = fill!(similar(x, affine_sz), 1)
   b = fill!(similar(x, affine_sz), 0)
-
-  batchnorm(g, b, x, running_mean, running_var, momentum;
-            kws...)
+  return batchnorm(g, b, x, running_mean, running_var, momentum; kws...)
 end
 
 # NOTE: CuDNN supports only 4D and 5D Tensors for BatchNorm Operations
 # so reshape a 2D Tensor into 4D
-batchnorm(g::DenseCuArray{T}, b::DenseCuArray{T}, x::DenseCuArray{T,2},
-          running_mean, running_var, momentum;
-          kws...) where T<:Union{Float32, Float64} =
-  dropdims(batchnorm(g, b, reshape(x, 1, 1, size(x, 1), size(x, 2)),
-                     running_mean, running_var, momentum;
-                     kws...),
-           dims = (1, 2))
+function batchnorm(g::DenseCuArray{T}, b::DenseCuArray{T}, x::DenseCuArray{T,2},
+                   running_mean, running_var, momentum; kws...) where T<:Union{Float32, Float64}
+  x = reshape(x, 1, 1, size(x, 1), size(x, 2))
+  y = batchnorm(g, b, x, running_mean, running_var, momentum; kws...)
+  return dropdims(y, dims = (1, 2))
+end
 
 function batchnorm(g::DenseCuArray{T}, b::DenseCuArray{T}, x::Union{DenseCuArray{T,4},DenseCuArray{T,5}},
-                   running_mean, running_var, momentum;
-                   kws...) where T<:Union{Float32, Float64}
+                   running_mean, running_var, momentum; kws...) where T<:Union{Float32, Float64}
   cudnnBNForward!(similar(x), g, b, x, running_mean, running_var, momentum; kws...)
 end
 
 function cudnnBNForward!(y::DenseCuArray{T}, g::DenseCuArray{T}, b::DenseCuArray{T}, x::DenseCuArray{T},
-                         running_mean, running_var, momentum;
-                         cache = nothing,
+                         running_mean, running_var, momentum;
+                         cache = nothing,
                          alpha = T(1), beta = T(0),
-                         eps = T(1e-5),
+                         eps = T(1e-5),
                          training = true,
                          affine = true,
                          track_stats = true) where T<:Union{Float32, Float64}
   dims = _wsize(x)
   if eps < CUDNN_BN_MIN_EPSILON
-    # warn("eps ",eps," is too small for CuDNN so eps has been assigned the value ", CUDNN_BN_MIN_EPSILON)
+    @warn "eps $eps is too small for CuDNN, setting to CUDNN_BN_MIN_EPSILON=$CUDNN_BN_MIN_EPSILON"
     eps = CUDNN_BN_MIN_EPSILON
   end
+
+  if running_mean === nothing || running_var === nothing
+    running_mean !== running_var && throw(ArgumentError("both or neither of running_mean and running_var must be nothing"))
+    if track_stats || !training
+      running_mean = fill!(similar(x, dims), 0)
+      running_var = fill!(similar(x, dims), 1)
+    end
+  end
+
   xd = cudnnTensorDescriptor(x)
   yd = cudnnTensorDescriptor(y)
   gd = cudnnTensorDescriptor(CUDNN_TENSOR_NCHW, cudnnDataType(T), Cint(length(dims)), dim4(dims,Val(CUDNN_TENSOR_NCHW)))
 
-
   if training
     if !track_stats
       running_mean = CU_NULL
       running_var = CU_NULL
     end
+
     if cache !== nothing
-      mean = zeros(CuArray{T}, dims...)
-      ivar = ones(CuArray{T}, dims...)
+      mean = fill!(similar(x, dims), 0)
+      ivar = fill!(similar(x, dims), 1)
     else
       mean = CU_NULL
       ivar = CU_NULL
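
(For context on the `_wsize` change above: the old definition returned the full size of `x` with only the channel dimension set to 1, while the affine parameters and batch statistics need the opposite, a singleton in every dimension except the channel one. A plain-Julia sketch, runnable on CPU; the helper names and shapes here are illustrative, not part of the package:)

```julia
# Old vs. new definitions side by side; CPU arrays are enough to see the shapes.
old_wsize(y) = ntuple(i -> i == ndims(y) - 1 ? 1 : size(y, i), ndims(y))
new_wsize(x::AbstractArray{<:Any,N}) where N = ntuple(i -> i == N-1 ? size(x, N-1) : 1, N)

x = rand(Float32, 7, 7, 3, 16)  # W × H × C × N

old_wsize(x)  # (7, 7, 1, 16): singleton channel dim, everything else kept
new_wsize(x)  # (1, 1, 3, 1):  per-channel shape for scale/bias and statistics
```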
@@ -86,11 +90,11 @@ function cudnnBNForward!(y::DenseCuArray{T}, g::DenseCuArray{T}, b::DenseCuArray
 end
 
 function ∇batchnorm(g::Nothing, b::Nothing, x::DenseCuArray, dy::DenseCuArray,
-                    running_mean, running_var, momentum; kws...)
+                    running_mean, running_var, momentum; kws...)
   affine_sz = _wsize(x)
   g = fill!(similar(x, affine_sz), 1)
   b = fill!(similar(x, affine_sz), 0)
-  ∇batchnorm(g, b, x, dy, running_mean, running_var, momentum; kws...)
+  return ∇batchnorm(g, b, x, dy, running_mean, running_var, momentum; kws...)
 end
 
 function ∇batchnorm(g::DenseCuArray{T}, b::DenseCuArray{T}, x::DenseCuArray{T, 2}, dy::DenseCuArray{T, 2},
@@ -112,7 +116,7 @@ function ∇batchnorm(g::DenseCuArray{T}, b::DenseCuArray{T}, x::DenseCuArray{T}
   if affine
     (dg, db, dx)
   else
-    # CUDNN always calculates dg and db, therefore we just have to drop them
+    # CUDNN always calculates dg and db, therefore we just have to drop them
     (nothing, nothing, dx)
   end
 end
@@ -122,9 +126,8 @@ function cudnnBNBackward!(dg::DenseCuArray{T}, g::DenseCuArray{T}, db::DenseCuAr
                           running_mean, running_var,
                           momentum; cache = nothing, eps = T(1e-5),
                           alpha = T(1), beta = T(0),
-                          dalpha = T(1), dbeta = T(0), training = true,
+                          dalpha = T(1), dbeta = T(0), training = true,
                           track_stats = true) where T<:Union{Float32, Float64}
-
   if !track_stats
     running_mean = CU_NULL
     running_var = CU_NULL
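
(With `track_stats = false`, both the forward and backward paths swap the running statistics for `CU_NULL` before calling into cuDNN, so no running averages are read or updated and normalization uses the current batch. A hedged standalone sketch with the same shapes the tests use, assuming a working CUDA device:)

```julia
using CUDA, NNlibCUDA

v = CUDA.rand(Float32, 2)      # per-channel affine params (2 channels)
m = CUDA.rand(Float32, 2, 5)   # a 2×5 batch, reshaped to 4D internally

# No running statistics are kept; normalization uses the batch itself.
NNlibCUDA.batchnorm(v, v, m, nothing, nothing, 1.0; training=true, track_stats=false)
```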
@@ -135,27 +138,18 @@ function cudnnBNBackward!(dg::DenseCuArray{T}, g::DenseCuArray{T}, db::DenseCuAr
   dxd = cudnnTensorDescriptor(dx)
   gd = cudnnTensorDescriptor(CUDNN_TENSOR_NCHW, cudnnDataType(T), Cint(length(_wsize(x))), dim4(_wsize(x),Val(CUDNN_TENSOR_NCHW)))
   if cache !== nothing
+    @debug "fetching mean and ivar from the cache"
     mean, ivar = cache.mean, cache.ivar
-    @debug "mean and ivar are fetched from the cache"
   else
     mean, ivar = CU_NULL, CU_NULL
   end
 
   if eps < CUDNN_BN_MIN_EPSILON
+    @warn "eps $eps is too small for CuDNN, setting to CUDNN_BN_MIN_EPSILON=$CUDNN_BN_MIN_EPSILON"
     eps = CUDNN_BN_MIN_EPSILON
   end
 
-  cudnnBatchNormalizationBackward(handle(), CUDNN_BATCHNORM_SPATIAL,
-                                  scalingParameter(T, alpha), scalingParameter(T, beta), scalingParameter(T, dalpha), scalingParameter(T, dbeta),
-                                  xd, x, dyd, dy, dxd, dx, gd, g, dg, db, eps,
-                                  mean, ivar)
-end
-
-function rrule(::typeof(batchnorm), g, b, x, running_mean, running_var, momentum; kws...)
-  y = batchnorm(g, b, x, running_mean, running_var, momentum; kws...)
-  function batchnorm_pullback(Δ)
-    dg, db, dx = ∇batchnorm(g, b, x, Δ, running_mean, running_var, momentum; kws...)
-    NoTangent(), something(dg, NoTangent()), something(db, NoTangent()), dx, NoTangent(), NoTangent(), NoTangent()
-  end
-  y, batchnorm_pullback
+  cudnnBatchNormalizationBackward(handle(), CUDNN_BATCHNORM_SPATIAL,
+                                  scalingParameter(T, alpha), scalingParameter(T, beta), scalingParameter(T, dalpha), scalingParameter(T, dbeta),
+                                  xd, x, dyd, dy, dxd, dx, gd, g, dg, db, eps, mean, ivar)
 end
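
(Taken together, the forward-path changes let callers omit the affine parameters entirely: passing `nothing` for both `g` and `b` materializes a per-channel scale of ones and bias of zeros before dispatching to the typed methods, and `∇batchnorm` mirrors this and still returns a 3-tuple. A usage sketch, assuming a working CUDA device; the shapes and the `0.1` momentum are illustrative:)

```julia
using CUDA, NNlibCUDA

x = CUDA.rand(Float32, 7, 7, 3, 16)   # W × H × C × N
v = CUDA.rand(Float32, 3)             # per-channel running mean / variance

# Identity affine transform: g is filled with ones, b with zeros.
y = NNlibCUDA.batchnorm(nothing, nothing, x, v, v, 0.1)

# The backward pass returns a 3-tuple (dg, db, dx).
dg, db, dx = NNlibCUDA.∇batchnorm(nothing, nothing, x, y, v, v, 0.1)
```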

ext/NNlibCUDA/test/batchnorm.jl

Lines changed: 23 additions & 11 deletions
@@ -1,15 +1,27 @@
 @testset "Batchnorm" begin
   v = CUDA.rand(Float32, 2)
   m = CUDA.rand(Float32, 2, 5)
-  for training in (false, true)
-    NNlibCUDA.batchnorm(v, v, m, v, v, 1.0; training=training)
-    NNlibCUDA.∇batchnorm(v, v, m, m, v, v, 1.0; training=training)
-  end
-
-  @testset "track_stats=false" begin
-    for training in (false, true)
-      NNlibCUDA.batchnorm(v, v, m, nothing, nothing, 1.0; training=training, track_stats=false)
-      NNlibCUDA.∇batchnorm(v, v, m, m, nothing, nothing, 1.0; training=training, track_stats=false)
-    end
-  end
+
+  @testset for training in (true, false), track_stats in (true, false)
+    kws = (training=training, track_stats=track_stats)
+
+    # Normal
+    NNlibCUDA.batchnorm(v, v, m, v, v, 1.0; kws...)
+    NNlibCUDA.∇batchnorm(v, v, m, m, v, v, 1.0; kws...)
+
+    # No affine
+    NNlibCUDA.batchnorm(nothing, nothing, m, v, v, 1.0; kws...)
+    NNlibCUDA.∇batchnorm(nothing, nothing, m, m, v, v, 1.0; kws...)
+
+    # No tracking
+    NNlibCUDA.batchnorm(v, v, m, nothing, nothing, 1.0; kws...)
+    NNlibCUDA.∇batchnorm(v, v, m, m, nothing, nothing, 1.0; kws...)
+
+    # Both or neither of the tracked and affine params must be set
+    for (α, β) in ((v, nothing), (nothing, v))
+      @test_throws MethodError NNlibCUDA.batchnorm(α, β, m, v, v, 1.0; kws...)
+      @test_throws MethodError NNlibCUDA.∇batchnorm(α, β, m, m, v, v, 1.0; kws...)
+      @test_throws ArgumentError NNlibCUDA.batchnorm(v, v, m, α, β, 1.0; kws...)
+    end
+  end
 end
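
(The mixed-`nothing` loop above pins down both failure modes introduced by this commit; the same checks can be run standalone, assuming a CUDA device:)

```julia
using CUDA, NNlibCUDA, Test

v = CUDA.rand(Float32, 2)
m = CUDA.rand(Float32, 2, 5)

# Only one affine parameter: no method matches, since g and b must both
# be arrays or both be `nothing`.
@test_throws MethodError NNlibCUDA.batchnorm(v, nothing, m, v, v, 1.0)

# Only one running statistic: rejected by the new check in cudnnBNForward!.
@test_throws ArgumentError NNlibCUDA.batchnorm(v, v, m, v, nothing, 1.0)
```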
