@@ -18,25 +18,26 @@ BNCache() = BNCache(nothing, nothing)
18
18
# NOTE: CuDNN supports only 4D and 5D Tensors for BatchNorm Operations
19
19
# so reshape a 2D Tensor into 4D
20
20
# batchnorm, 2D-input method.
# Views `x` as a 1 x 1 x size(x, 1) x size(x, 2) array via `reshape`, calls `batchnorm`
# again on that 4D array with the same remaining positional arguments, and removes the
# two leading singleton dimensions from the result with `dropdims(...; dims = (1, 2))`.
# In the `+` version `running_mean` / `running_var` carry no type annotation and every
# keyword is forwarded untouched through `kws...` instead of being listed one by one.
# NOTE(review): the untyped arguments presumably allow non-array placeholders when
# statistics are not tracked -- confirm against the callers.
batchnorm (g:: DenseCuArray{T} , b:: DenseCuArray{T} , x:: DenseCuArray{T,2} ,
21
- running_mean:: DenseCuArray{T} , running_var:: DenseCuArray{T} , momentum;
22
- cache = nothing , alpha = T (1 ), beta = T (0 ),
23
- eps = T (1e-5 ), training = true ) where T<: Union{Float32, Float64} =
24
- dropdims (batchnorm (g, b, reshape (x, 1 , 1 , size (x, 1 ), size (x, 2 )), running_mean, running_var, momentum,
25
- cache = cache, alpha = alpha, beta = beta, eps = eps, training = training), dims = (1 , 2 ))
21
+ running_mean, running_var, momentum;
22
+ kws... ) where T<: Union{Float32, Float64} =
23
+ dropdims (batchnorm (g, b, reshape (x, 1 , 1 , size (x, 1 ), size (x, 2 )),
24
+ running_mean, running_var, momentum;
25
+ kws... ),
26
+ dims = (1 , 2 ))
26
27
27
28
# batchnorm, 4D/5D-input method.
# Allocates an output buffer with `similar(x)` and hands it, together with `g`, `b`, `x`,
# `running_mean`, `running_var` and `momentum`, to `cudnnBNForward!`; the value of that
# call is the value of this function.
# In the `+` version all keywords are forwarded as `kws...`, so the accepted keywords and
# their defaults are whatever `cudnnBNForward!` declares.
function batchnorm (g:: DenseCuArray{T} , b:: DenseCuArray{T} , x:: Union{DenseCuArray{T,4},DenseCuArray{T,5}} ,
28
- running_mean:: DenseCuArray{T} , running_var:: DenseCuArray{T} , momentum;
29
- cache = nothing , alpha = T (1 ), beta = T (0 ),
30
- eps = T (1e-5 ), training = true ) where T<: Union{Float32, Float64}
31
- cudnnBNForward! (similar (x), g, b, x, running_mean, running_var, momentum, cache = cache,
32
- alpha = alpha, beta = beta, eps = eps, training = training)
29
+ running_mean, running_var, momentum;
30
+ kws... ) where T<: Union{Float32, Float64}
31
+ cudnnBNForward! (similar (x), g, b, x, running_mean, running_var, momentum; kws... )
33
32
end
34
33
35
34
function cudnnBNForward! (y:: DenseCuArray{T} , g:: DenseCuArray{T} , b:: DenseCuArray{T} , x:: DenseCuArray{T} ,
36
- running_mean:: DenseCuArray{T} , running_var:: DenseCuArray{T} ,
37
- momentum; cache = nothing ,
35
+ running_mean, running_var, momentum;
36
+ cache = nothing ,
38
37
alpha = T (1 ), beta = T (0 ),
39
- eps = T (1e-5 ), training = true ) where T<: Union{Float32, Float64}
38
+ eps = T (1e-5 ),
39
+ training = true ,
40
+ track_stats = true ) where T<: Union{Float32, Float64}
40
41
dims = _wsize (x)
41
42
if eps < CUDNN_BN_MIN_EPSILON
42
43
# warn("eps ",eps," is too small for CuDNN so eps has been assigned the value ", CUDNN_BN_MIN_EPSILON)
@@ -46,8 +47,12 @@ function cudnnBNForward!(y::DenseCuArray{T}, g::DenseCuArray{T}, b::DenseCuArray
46
47
yd = cudnnTensorDescriptor (y)
47
48
gd = cudnnTensorDescriptor (CUDNN_TENSOR_NCHW, cudnnDataType (T), Cint (length (dims)), dim4 (dims,Val (CUDNN_TENSOR_NCHW)))
48
49
49
- if training
50
+ if ! track_stats
51
+ running_mean = CU_NULL
52
+ running_var = CU_NULL
53
+ end
50
54
55
+ if training
51
56
if cache != = nothing
52
57
mean = zeros (CuArray{T}, dims... )
53
58
ivar = ones (CuArray{T}, dims... )
@@ -69,33 +74,36 @@ function cudnnBNForward!(y::DenseCuArray{T}, g::DenseCuArray{T}, b::DenseCuArray
69
74
end
70
75
71
76
# ∇batchnorm, 2D-input method (both `x` and `dy` are 2D).
# Reshapes `x` and `dy` to 1 x 1 x size(_, 1) x size(_, 2), calls `∇batchnorm` on the
# reshaped arrays with the same remaining positional arguments, and returns
# `(dg, db, dx)` where `dx` has its two leading singleton dimensions dropped again.
# In the `+` version `running_mean` / `running_var` are untyped and the keywords are
# forwarded as `kws...` rather than enumerated.
function ∇batchnorm (g:: DenseCuArray{T} , b:: DenseCuArray{T} , x:: DenseCuArray{T, 2} , dy:: DenseCuArray{T, 2} ,
72
- running_mean:: DenseCuArray{T} , running_var:: DenseCuArray{T} , momentum;
73
- cache = nothing , eps = T (1e-5 ), alpha = T (1 ),
74
- beta = T (0 ), training = true ) where T<: Union{Float32, Float64}
77
+ running_mean, running_var, momentum;
78
+ kws... ) where T<: Union{Float32, Float64}
75
79
dg, db, dx = ∇batchnorm (g, b, reshape (x, 1 , 1 , size (x, 1 ), size (x, 2 )), reshape (dy, 1 , 1 , size (dy, 1 ),
76
- size (dy, 2 )), running_mean, running_var, momentum, cache = cache, eps = eps,
77
- alpha = alpha, beta = beta, training = training)
80
+ size (dy, 2 )), running_mean, running_var, momentum; kws... )
78
81
(dg, db, dropdims (dx, dims = (1 , 2 )))
79
82
end
80
83
81
84
# ∇batchnorm, general method.
# Allocates `dg`, `db` and `dx` with `similar` (matching `g`, `b` and `x` respectively),
# passes them to `cudnnBNBackward!` along with the inputs, and returns `(dg, db, dx)`.
# `momentum` is converted to the element type `T` before the call.
# In the `+` version the keywords are forwarded as `kws...`, so the accepted keywords and
# their defaults are whatever `cudnnBNBackward!` declares.
# NOTE(review): the buffers are presumably filled in place by `cudnnBNBackward!` (its
# name ends in `!`); its full body is not visible here -- confirm.
function ∇batchnorm (g:: DenseCuArray{T} , b:: DenseCuArray{T} , x:: DenseCuArray{T} , dy:: DenseCuArray{T} ,
82
- running_mean:: DenseCuArray{T} , running_var:: DenseCuArray{T} , momentum;
83
- cache = nothing , eps = T (1e-5 ), alpha = T (1 ),
84
- beta = T (0 ), training = true ) where T<: Union{Float32, Float64}
85
+ running_mean, running_var, momentum;
86
+ kws... ) where T<: Union{Float32, Float64}
85
87
dg = similar (g)
86
88
db = similar (b)
87
89
dx = similar (x)
88
- cudnnBNBackward! (dg, g, db, dx, x, dy, running_mean, running_var, T (momentum),
89
- training = training, cache = cache, eps = eps, alpha = alpha, beta = beta)
90
+ cudnnBNBackward! (dg, g, db, dx, x, dy, running_mean, running_var, T (momentum); kws... )
90
91
(dg, db, dx)
91
92
end
92
93
93
94
function cudnnBNBackward! (dg:: DenseCuArray{T} , g:: DenseCuArray{T} , db:: DenseCuArray{T} ,
94
95
dx:: DenseCuArray{T} , x:: DenseCuArray{T} , dy:: DenseCuArray{T} ,
95
- running_mean:: DenseCuArray{T} , running_var:: DenseCuArray{T} ,
96
+ running_mean, running_var,
96
97
momentum; cache = nothing , eps = T (1e-5 ),
97
98
alpha = T (1 ), beta = T (0 ),
98
- dalpha = T (1 ), dbeta = T (0 ), training = true ) where T<: Union{Float32, Float64}
99
+ dalpha = T (1 ), dbeta = T (0 ), training = true ,
100
+ track_stats = true ) where T<: Union{Float32, Float64}
101
+
102
+ if ! track_stats
103
+ running_mean = CU_NULL
104
+ running_var = CU_NULL
105
+ end
106
+
99
107
if training
100
108
xd = cudnnTensorDescriptor (x)
101
109
dyd = cudnnTensorDescriptor (dy)
@@ -121,4 +129,4 @@ function cudnnBNBackward!(dg::DenseCuArray{T}, g::DenseCuArray{T}, db::DenseCuAr
121
129
db .= vec (sum (dy, dims= rdims))
122
130
end
123
131
end
124
-
132
+
0 commit comments