# Zero-argument constructor: an empty cache with no stored batch
# statistics (mean / inverse variance) yet.
BNCache() = BNCache(nothing, nothing)
# Shape of the per-channel (affine / running-statistics) parameters for
# input `x`: singleton in every dimension except the channel dimension,
# which by the cuDNN / NNlib convention is dimension N-1 (second-to-last).
# E.g. a (W, H, C, N) input yields (1, 1, C, 1).
#
# The superseded variant computed the complement (full size everywhere,
# 1 only in the channel dim); this is the corrected form.
@inline function _wsize(x::AbstractArray{<:Any,N}) where {N}
    return ntuple(d -> d == N - 1 ? size(x, N - 1) : 1, N)
end
"""
    batchnorm(g::Nothing, b::Nothing, x::DenseCuArray,
              running_mean, running_var, momentum; kws...)

Batchnorm without a learned affine transform: materialise identity
scale (`γ = 1`) and bias (`β = 0`) arrays of the per-channel shape
and dispatch to the main `batchnorm` method.
"""
function batchnorm(g::Nothing, b::Nothing, x::DenseCuArray,
                   running_mean, running_var, momentum; kws...)
    affine_sz = _wsize(x)
    g = fill!(similar(x, affine_sz), 1)
    b = fill!(similar(x, affine_sz), 0)
    return batchnorm(g, b, x, running_mean, running_var, momentum; kws...)
end
# NOTE: cuDNN supports only 4D and 5D tensors for batchnorm operations,
# so reshape a 2D (features, batch) tensor into a 4D (1, 1, features, batch)
# one, run batchnorm on it, and drop the two singleton spatial dims again.
function batchnorm(g::DenseCuArray{T}, b::DenseCuArray{T}, x::DenseCuArray{T,2},
                   running_mean, running_var, momentum; kws...) where T<:Union{Float32, Float64}
    x = reshape(x, 1, 1, size(x, 1), size(x, 2))
    y = batchnorm(g, b, x, running_mean, running_var, momentum; kws...)
    return dropdims(y, dims = (1, 2))
end
# Main 4D/5D entry point: allocate an output array of the same shape/type
# as `x` and run the cuDNN batchnorm forward kernel into it.
function batchnorm(g::DenseCuArray{T}, b::DenseCuArray{T}, x::Union{DenseCuArray{T,4},DenseCuArray{T,5}},
                   running_mean, running_var, momentum; kws...) where T<:Union{Float32, Float64}
    return cudnnBNForward!(similar(x), g, b, x, running_mean, running_var, momentum; kws...)
end
45
40
function cudnnBNForward! (y:: DenseCuArray{T} , g:: DenseCuArray{T} , b:: DenseCuArray{T} , x:: DenseCuArray{T} ,
46
- running_mean, running_var, momentum;
47
- cache = nothing ,
41
+ running_mean, running_var, momentum;
42
+ cache = nothing ,
48
43
alpha = T (1 ), beta = T (0 ),
49
- eps = T (1e-5 ),
44
+ eps = T (1e-5 ),
50
45
training = true ,
51
46
affine = true ,
52
47
track_stats = true ) where T<: Union{Float32, Float64}
53
48
dims = _wsize (x)
54
49
if eps < CUDNN_BN_MIN_EPSILON
55
- # warn( "eps ", eps," is too small for CuDNN so eps has been assigned the value ", CUDNN_BN_MIN_EPSILON)
50
+ @ warn " eps $ eps is too small for CuDNN, setting to CUDNN_BN_MIN_EPSILON= $CUDNN_BN_MIN_EPSILON "
56
51
eps = CUDNN_BN_MIN_EPSILON
57
52
end
53
+
54
+ if running_mean === nothing || running_var === nothing
55
+ running_mean != = running_var && throw (ArgumentError (" both or neither of running_mean and running_var must be nothing" ))
56
+ if track_stats || ! training
57
+ running_mean = fill! (similar (x, dims), 0 )
58
+ running_var = fill! (similar (x, dims), 1 )
59
+ end
60
+ end
61
+
58
62
xd = cudnnTensorDescriptor (x)
59
63
yd = cudnnTensorDescriptor (y)
60
64
gd = cudnnTensorDescriptor (CUDNN_TENSOR_NCHW, cudnnDataType (T), Cint (length (dims)), dim4 (dims,Val (CUDNN_TENSOR_NCHW)))
61
65
62
-
63
66
if training
64
67
if ! track_stats
65
68
running_mean = CU_NULL
66
69
running_var = CU_NULL
67
70
end
71
+
68
72
if cache != = nothing
69
- mean = zeros (CuArray{T} , dims... )
70
- ivar = ones (CuArray{T} , dims... )
73
+ mean = fill! ( similar (x , dims), 0 )
74
+ ivar = fill! ( similar (x , dims), 1 )
71
75
else
72
76
mean = CU_NULL
73
77
ivar = CU_NULL
@@ -86,11 +90,11 @@ function cudnnBNForward!(y::DenseCuArray{T}, g::DenseCuArray{T}, b::DenseCuArray
86
90
end
87
91
88
"""
    ∇batchnorm(g::Nothing, b::Nothing, x::DenseCuArray, dy::DenseCuArray,
               running_mean, running_var, momentum; kws...)

Gradient of batchnorm without a learned affine transform: build the same
identity `γ = 1`, `β = 0` arrays as the forward pass and dispatch to the
main `∇batchnorm` method.
"""
function ∇batchnorm(g::Nothing, b::Nothing, x::DenseCuArray, dy::DenseCuArray,
                    running_mean, running_var, momentum; kws...)
    affine_sz = _wsize(x)
    g = fill!(similar(x, affine_sz), 1)
    b = fill!(similar(x, affine_sz), 0)
    return ∇batchnorm(g, b, x, dy, running_mean, running_var, momentum; kws...)
end
96
100
function ∇batchnorm (g:: DenseCuArray{T} , b:: DenseCuArray{T} , x:: DenseCuArray{T, 2} , dy:: DenseCuArray{T, 2} ,
@@ -112,7 +116,7 @@ function ∇batchnorm(g::DenseCuArray{T}, b::DenseCuArray{T}, x::DenseCuArray{T}
112
116
if affine
113
117
(dg, db, dx)
114
118
else
115
- # CUDNN always calculates dg and db, therefore we just have to drop them
119
+ # CUDNN always calculates dg and db, therefore we just have to drop them
116
120
(nothing , nothing , dx)
117
121
end
118
122
end
@@ -122,9 +126,8 @@ function cudnnBNBackward!(dg::DenseCuArray{T}, g::DenseCuArray{T}, db::DenseCuAr
122
126
running_mean, running_var,
123
127
momentum; cache = nothing , eps = T (1e-5 ),
124
128
alpha = T (1 ), beta = T (0 ),
125
- dalpha = T (1 ), dbeta = T (0 ), training = true ,
129
+ dalpha = T (1 ), dbeta = T (0 ), training = true ,
126
130
track_stats = true ) where T<: Union{Float32, Float64}
127
-
128
131
if ! track_stats
129
132
running_mean = CU_NULL
130
133
running_var = CU_NULL
@@ -135,27 +138,18 @@ function cudnnBNBackward!(dg::DenseCuArray{T}, g::DenseCuArray{T}, db::DenseCuAr
135
138
dxd = cudnnTensorDescriptor (dx)
136
139
gd = cudnnTensorDescriptor (CUDNN_TENSOR_NCHW, cudnnDataType (T), Cint (length (_wsize (x))), dim4 (_wsize (x),Val (CUDNN_TENSOR_NCHW)))
137
140
if cache != = nothing
141
+ @debug " fetching mean and ivar from the cache"
138
142
mean, ivar = cache. mean, cache. ivar
139
- @debug " mean and ivar are fetched from the cache"
140
143
else
141
144
mean, ivar = CU_NULL, CU_NULL
142
145
end
143
146
144
147
if eps < CUDNN_BN_MIN_EPSILON
148
+ @warn " eps $eps is too small for CuDNN, setting to CUDNN_BN_MIN_EPSILON=$CUDNN_BN_MIN_EPSILON "
145
149
eps = CUDNN_BN_MIN_EPSILON
146
150
end
147
151
148
- cudnnBatchNormalizationBackward (handle (), CUDNN_BATCHNORM_SPATIAL,
149
- scalingParameter (T, alpha), scalingParameter (T, beta), scalingParameter (T, dalpha), scalingParameter (T, dbeta),
150
- xd, x, dyd, dy, dxd, dx, gd, g, dg, db, eps,
151
- mean, ivar)
152
- end
153
-
154
- function rrule (:: typeof (batchnorm), g, b, x, running_mean, running_var, momentum; kws... )
155
- y = batchnorm (g, b, x, running_mean, running_var, momentum; kws... )
156
- function batchnorm_pullback (Δ)
157
- dg, db, dx = ∇batchnorm (g, b, x, Δ, running_mean, running_var, momentum; kws... )
158
- NoTangent (), something (dg, NoTangent ()), something (db, NoTangent ()), dx, NoTangent (), NoTangent (), NoTangent ()
159
- end
160
- y, batchnorm_pullback
152
+ cudnnBatchNormalizationBackward (handle (), CUDNN_BATCHNORM_SPATIAL,
153
+ scalingParameter (T, alpha), scalingParameter (T, beta), scalingParameter (T, dalpha), scalingParameter (T, dbeta),
154
+ xd, x, dyd, dy, dxd, dx, gd, g, dg, db, eps, mean, ivar)
161
155
end
0 commit comments