
Commit 4221f71

track_stats option for batchnorm
1 parent: 451901c

File tree

2 files changed: +42 −27 lines

ext/NNlibCUDA/src/cudnn/batchnorm.jl
Lines changed: 35 additions & 27 deletions

@@ -18,25 +18,26 @@ BNCache() = BNCache(nothing, nothing)
 # NOTE: CuDNN supports only 4D and 5D Tensors for BatchNorm Operations
 # so reshape a 2D Tensor into 4D
 batchnorm(g::DenseCuArray{T}, b::DenseCuArray{T}, x::DenseCuArray{T,2},
-          running_mean::DenseCuArray{T}, running_var::DenseCuArray{T}, momentum;
-          cache = nothing, alpha = T(1), beta = T(0),
-          eps = T(1e-5), training = true) where T<:Union{Float32, Float64} =
-  dropdims(batchnorm(g, b, reshape(x, 1, 1, size(x, 1), size(x, 2)), running_mean, running_var, momentum,
-                     cache = cache, alpha = alpha, beta = beta, eps = eps, training = training), dims = (1, 2))
+          running_mean, running_var, momentum;
+          kws...) where T<:Union{Float32, Float64} =
+  dropdims(batchnorm(g, b, reshape(x, 1, 1, size(x, 1), size(x, 2)),
+                     running_mean, running_var, momentum;
+                     kws...),
+           dims = (1, 2))
 
 function batchnorm(g::DenseCuArray{T}, b::DenseCuArray{T}, x::Union{DenseCuArray{T,4},DenseCuArray{T,5}},
-                   running_mean::DenseCuArray{T}, running_var::DenseCuArray{T}, momentum;
-                   cache = nothing, alpha = T(1), beta = T(0),
-                   eps = T(1e-5), training = true) where T<:Union{Float32, Float64}
-  cudnnBNForward!(similar(x), g, b, x, running_mean, running_var, momentum, cache = cache,
-                  alpha = alpha, beta = beta, eps = eps, training = training)
+                   running_mean, running_var, momentum;
+                   kws...) where T<:Union{Float32, Float64}
+  cudnnBNForward!(similar(x), g, b, x, running_mean, running_var, momentum; kws...)
 end
 
 function cudnnBNForward!(y::DenseCuArray{T}, g::DenseCuArray{T}, b::DenseCuArray{T}, x::DenseCuArray{T},
-                         running_mean::DenseCuArray{T}, running_var::DenseCuArray{T},
-                         momentum; cache = nothing,
+                         running_mean, running_var, momentum;
+                         cache = nothing,
                          alpha = T(1), beta = T(0),
-                         eps = T(1e-5), training = true) where T<:Union{Float32, Float64}
+                         eps = T(1e-5),
+                         training = true,
+                         track_stats = true) where T<:Union{Float32, Float64}
   dims = _wsize(x)
   if eps < CUDNN_BN_MIN_EPSILON
     # warn("eps ",eps," is too small for CuDNN so eps has been assigned the value ", CUDNN_BN_MIN_EPSILON)
@@ -46,8 +47,12 @@ function cudnnBNForward!(y::DenseCuArray{T}, g::DenseCuArray{T}, b::DenseCuArray
   yd = cudnnTensorDescriptor(y)
   gd = cudnnTensorDescriptor(CUDNN_TENSOR_NCHW, cudnnDataType(T), Cint(length(dims)), dim4(dims,Val(CUDNN_TENSOR_NCHW)))
 
-  if training
+  if !track_stats
+    running_mean = CU_NULL
+    running_var = CU_NULL
+  end
 
+  if training
     if cache !== nothing
       mean = zeros(CuArray{T}, dims...)
       ivar = ones(CuArray{T}, dims...)
@@ -69,33 +74,36 @@ function cudnnBNForward!(y::DenseCuArray{T}, g::DenseCuArray{T}, b::DenseCuArray
 end
 
 function ∇batchnorm(g::DenseCuArray{T}, b::DenseCuArray{T}, x::DenseCuArray{T, 2}, dy::DenseCuArray{T, 2},
-                    running_mean::DenseCuArray{T}, running_var::DenseCuArray{T}, momentum;
-                    cache = nothing, eps = T(1e-5), alpha = T(1),
-                    beta = T(0), training = true) where T<:Union{Float32, Float64}
+                    running_mean, running_var, momentum;
+                    kws...) where T<:Union{Float32, Float64}
   dg, db, dx = ∇batchnorm(g, b, reshape(x, 1, 1, size(x, 1), size(x, 2)), reshape(dy, 1, 1, size(dy, 1),
-                          size(dy, 2)), running_mean, running_var, momentum, cache = cache, eps = eps,
-                          alpha = alpha, beta = beta, training = training)
+                          size(dy, 2)), running_mean, running_var, momentum; kws...)
   (dg, db, dropdims(dx, dims = (1, 2)))
 end
 
 function ∇batchnorm(g::DenseCuArray{T}, b::DenseCuArray{T}, x::DenseCuArray{T}, dy::DenseCuArray{T},
-                    running_mean::DenseCuArray{T}, running_var::DenseCuArray{T}, momentum;
-                    cache = nothing, eps = T(1e-5), alpha = T(1),
-                    beta = T(0), training = true) where T<:Union{Float32, Float64}
+                    running_mean, running_var, momentum;
+                    kws...) where T<:Union{Float32, Float64}
   dg = similar(g)
   db = similar(b)
   dx = similar(x)
-  cudnnBNBackward!(dg, g, db, dx, x, dy, running_mean, running_var, T(momentum),
-                   training = training, cache = cache, eps = eps, alpha = alpha, beta = beta)
+  cudnnBNBackward!(dg, g, db, dx, x, dy, running_mean, running_var, T(momentum); kws...)
   (dg, db, dx)
 end
 
 function cudnnBNBackward!(dg::DenseCuArray{T}, g::DenseCuArray{T}, db::DenseCuArray{T},
                           dx::DenseCuArray{T}, x::DenseCuArray{T}, dy::DenseCuArray{T},
-                          running_mean::DenseCuArray{T}, running_var::DenseCuArray{T},
+                          running_mean, running_var,
                           momentum; cache = nothing, eps = T(1e-5),
                           alpha = T(1), beta = T(0),
-                          dalpha = T(1), dbeta = T(0), training = true) where T<:Union{Float32, Float64}
+                          dalpha = T(1), dbeta = T(0), training = true,
+                          track_stats = true) where T<:Union{Float32, Float64}
+
+  if !track_stats
+    running_mean = CU_NULL
+    running_var = CU_NULL
+  end
+
   if training
     xd = cudnnTensorDescriptor(x)
     dyd = cudnnTensorDescriptor(dy)
@@ -121,4 +129,4 @@ function cudnnBNBackward!(dg::DenseCuArray{T}, g::DenseCuArray{T}, db::DenseCuAr
   db .= vec(sum(dy, dims=rdims))
 end
 end
-
+
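The substance of the change is the new track_stats keyword: when the caller opts out of statistics tracking, running_mean and running_var are swapped for CU_NULL before the call into cuDNN, so the library neither reads nor updates the running estimates and normalization relies on batch statistics alone. The following is a minimal CPU sketch of those semantics, an illustrative assumption rather than code from this commit; batchnorm_ref, its argument layout, and its defaults are invented here.

using Statistics

# Reference forward pass for a W×H×C×N input x with length-C scale g and bias b.
function batchnorm_ref(g, b, x;
                       running_mean = nothing, running_var = nothing,
                       momentum = 0.1f0, eps = 1f-5,
                       training = true, track_stats = true)
    dims = (1, 2, 4)                      # reduce over everything but channels
    if training || !track_stats
        mu = mean(x; dims)                # batch statistics
        s2 = var(x; dims, corrected = false)
        if training && track_stats && running_mean !== nothing
            # in-place running-estimate update, mirroring what cuDNN does on GPU
            running_mean .= (1 - momentum) .* running_mean .+ momentum .* vec(mu)
            running_var  .= (1 - momentum) .* running_var  .+ momentum .* vec(s2)
        end
    else                                  # inference with tracked statistics
        mu = reshape(running_mean, 1, 1, :, 1)
        s2 = reshape(running_var,  1, 1, :, 1)
    end
    reshape(g, 1, 1, :, 1) .* (x .- mu) ./ sqrt.(s2 .+ eps) .+ reshape(b, 1, 1, :, 1)
end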

ext/NNlibCUDA/test/batchnorm.jl
Lines changed: 7 additions & 0 deletions

@@ -5,4 +5,11 @@
         NNlibCUDA.batchnorm(v, v, m, v, v, 1.0; training=training)
         NNlibCUDA.∇batchnorm(v, v, m, m, v, v, 1.0; training=training)
     end
+
+    @testset "track_stats=false" begin
+        for training in (false, true)
+            NNlibCUDA.batchnorm(v, v, m, nothing, nothing, 1.0; training=training)
+            NNlibCUDA.∇batchnorm(v, v, m, m, nothing, nothing, 1.0; training=training)
+        end
+    end
 end
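For context, a hedged usage sketch of the new keyword on GPU arrays; the shapes, values, and momentum here are assumptions, not taken from the test suite. Since batchnorm forwards its keyword arguments (kws...) down to cudnnBNForward!, track_stats can be passed at the top level:

using CUDA, NNlibCUDA

x  = CUDA.rand(Float32, 8, 8, 3, 4)   # W × H × C × N
g  = CUDA.ones(Float32, 3)            # per-channel scale
b  = CUDA.zeros(Float32, 3)           # per-channel bias
rm = CUDA.zeros(Float32, 3)           # running mean
rv = CUDA.ones(Float32, 3)            # running variance

# Default behaviour: running statistics are read and updated as before.
y1 = NNlibCUDA.batchnorm(g, b, x, rm, rv, 0.1; training = true)

# New option: rm/rv are replaced by CU_NULL internally and left untouched.
y2 = NNlibCUDA.batchnorm(g, b, x, rm, rv, 0.1; training = true, track_stats = false)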
