BNCache() = BNCache(nothing, nothing)

- @inline _wsize(y) = (fill(1, ndims(y)-2)..., size(y)[end-1], 1)
+ @inline _wsize(x::AbstractArray{<:Any,N}) where N = ntuple(i -> i == N-1 ? size(x, N-1) : 1, N)
+
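# Illustrative check of the new helper (annotation, not part of the patch):
# _wsize now returns the broadcastable per-channel parameter shape, keeping
# only the channel (second-to-last) dimension of x:
#
#   _wsize(rand(Float32, 28, 28, 3, 16))  # == (1, 1, 3, 1)
#   _wsize(rand(Float32, 5, 7))           # == (5, 1)
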
+ function batchnorm(g::Nothing, b::Nothing, x::DenseCuArray,
+                    running_mean, running_var, momentum; kws...)
+   affine_sz = _wsize(x)
+   g = fill!(similar(x, affine_sz), 1)
+   b = fill!(similar(x, affine_sz), 0)
+   return batchnorm(g, b, x, running_mean, running_var, momentum; kws...)
+ end
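# Sketch of the call this wrapper enables (hypothetical sizes, assumes CUDA.jl
# is loaded): g = b = nothing runs batchnorm without a learned scale/shift by
# substituting ones and zeros of the affine shape:
#
#   x = CUDA.rand(Float32, 8, 8, 3, 16)
#   μ, σ² = CUDA.zeros(Float32, 3), CUDA.ones(Float32, 3)
#   y = batchnorm(nothing, nothing, x, μ, σ², 0.1f0; training = true)
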

# NOTE: CuDNN supports only 4D and 5D Tensors for BatchNorm Operations
# so reshape a 2D Tensor into 4D
- batchnorm(g::DenseCuArray{T}, b::DenseCuArray{T}, x::DenseCuArray{T,2},
-           running_mean::DenseCuArray{T}, running_var::DenseCuArray{T}, momentum;
-           cache = nothing, alpha = T(1), beta = T(0),
-           eps = T(1e-5), training = true) where T<:Union{Float32, Float64} =
-   dropdims(batchnorm(g, b, reshape(x, 1, 1, size(x, 1), size(x, 2)), running_mean, running_var, momentum,
-                      cache = cache, alpha = alpha, beta = beta, eps = eps, training = training), dims = (1, 2))
+ function batchnorm(g::DenseCuArray{T}, b::DenseCuArray{T}, x::DenseCuArray{T,2},
+                    running_mean, running_var, momentum; kws...) where T<:CUDNNFloat
+   x = reshape(x, 1, 1, size(x, 1), size(x, 2))
+   y = batchnorm(g, b, x, running_mean, running_var, momentum; kws...)
+   return dropdims(y, dims = (1, 2))
+ end
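# The reshape trick in action (annotation): a (features, batch) matrix is
# lifted to 4D so cuDNN sees the feature dimension as channels, and the
# singleton dims are dropped again on the way out:
#
#   x = CUDA.rand(Float32, 5, 2)    # (features, batch)
#   size(reshape(x, 1, 1, 5, 2))    # == (1, 1, 5, 2)
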

function batchnorm(g::DenseCuArray{T}, b::DenseCuArray{T}, x::Union{DenseCuArray{T,4},DenseCuArray{T,5}},
-                  running_mean::DenseCuArray{T}, running_var::DenseCuArray{T}, momentum;
-                  cache = nothing, alpha = T(1), beta = T(0),
-                  eps = T(1e-5), training = true) where T<:Union{Float32, Float64}
-  cudnnBNForward!(similar(x), g, b, x, running_mean, running_var, momentum, cache = cache,
-                  alpha = alpha, beta = beta, eps = eps, training = training)
+                  running_mean, running_var, momentum; kws...) where T<:CUDNNFloat
+  cudnnBNForward!(similar(x), g, b, x, running_mean, running_var, momentum; kws...)
end

function cudnnBNForward!(y::DenseCuArray{T}, g::DenseCuArray{T}, b::DenseCuArray{T}, x::DenseCuArray{T},
-                        running_mean::DenseCuArray{T}, running_var::DenseCuArray{T},
-                        momentum; cache = nothing,
+                        running_mean, running_var, momentum;
+                        cache = nothing,
                         alpha = T(1), beta = T(0),
-                        eps = T(1e-5), training = true) where T<:Union{Float32, Float64}
+                        eps = T(1e-5),
+                        training = true,
+                        affine = true,
+                        track_stats = true) where T<:CUDNNFloat
  dims = _wsize(x)
  if eps < CUDNN_BN_MIN_EPSILON
-    # warn("eps ", eps, " is too small for CuDNN so eps has been assigned the value ", CUDNN_BN_MIN_EPSILON)
+    @warn "eps $eps is too small for CuDNN, setting to CUDNN_BN_MIN_EPSILON=$CUDNN_BN_MIN_EPSILON"
    eps = CUDNN_BN_MIN_EPSILON
  end
+
+  if running_mean === nothing || running_var === nothing
+    running_mean !== running_var && throw(ArgumentError("both or neither of running_mean and running_var must be nothing"))
+    if track_stats || !training
+      running_mean = fill!(similar(x, dims), 0)
+      running_var = fill!(similar(x, dims), 1)
+    end
+  end
+
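# Summary of how the new branch treats `nothing` statistics (annotation):
#
#   running stats  track_stats  training  behaviour
#   nothing        true         any       buffers allocated above (zeros/ones)
#   nothing        false        false     buffers allocated (needed at inference)
#   nothing        false        true      stay `nothing`; nulled to CU_NULL below
#   arrays         any          any       used as given
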
  xd = cudnnTensorDescriptor(x)
  yd = cudnnTensorDescriptor(y)
  gd = cudnnTensorDescriptor(CUDNN_TENSOR_NCHW, cudnnDataType(T), Cint(length(dims)), dim4(dims, Val(CUDNN_TENSOR_NCHW)))

  if training
+    if !track_stats
+      running_mean = CU_NULL
+      running_var = CU_NULL
+    end

    if cache !== nothing
-      mean = zeros(CuArray{T}, dims...)
-      ivar = ones(CuArray{T}, dims...)
+      mean = fill!(similar(x, dims), 0)
+      ivar = fill!(similar(x, dims), 1)
    else
      mean = CU_NULL
      ivar = CU_NULL
@@ -68,57 +89,67 @@ function cudnnBNForward!(y::DenseCuArray{T}, g::DenseCuArray{T}, b::DenseCuArray
  return y
end

+ function ∇batchnorm(g::Nothing, b::Nothing, x::DenseCuArray, dy::DenseCuArray,
+                     running_mean, running_var, momentum; kws...)
+   affine_sz = _wsize(x)
+   g = fill!(similar(x, affine_sz), 1)
+   b = fill!(similar(x, affine_sz), 0)
+   return ∇batchnorm(g, b, x, dy, running_mean, running_var, momentum; kws...)
+ end
+
function ∇batchnorm(g::DenseCuArray{T}, b::DenseCuArray{T}, x::DenseCuArray{T, 2}, dy::DenseCuArray{T, 2},
-                    running_mean::DenseCuArray{T}, running_var::DenseCuArray{T}, momentum;
-                    cache = nothing, eps = T(1e-5), alpha = T(1),
-                    beta = T(0), training = true) where T<:Union{Float32, Float64}
+                    running_mean, running_var, momentum;
+                    kws...) where T<:CUDNNFloat
  dg, db, dx = ∇batchnorm(g, b, reshape(x, 1, 1, size(x, 1), size(x, 2)), reshape(dy, 1, 1, size(dy, 1),
-                          size(dy, 2)), running_mean, running_var, momentum, cache = cache, eps = eps,
-                          alpha = alpha, beta = beta, training = training)
+                          size(dy, 2)), running_mean, running_var, momentum; kws...)
  (dg, db, dropdims(dx, dims = (1, 2)))
end

+
function ∇batchnorm(g::DenseCuArray{T}, b::DenseCuArray{T}, x::DenseCuArray{T}, dy::DenseCuArray{T},
-                    running_mean::DenseCuArray{T}, running_var::DenseCuArray{T}, momentum;
-                    cache = nothing, eps = T(1e-5), alpha = T(1),
-                    beta = T(0), training = true) where T<:Union{Float32, Float64}
+                    running_mean, running_var, momentum;
+                    affine = true, kws...) where T<:CUDNNFloat
  dg = similar(g)
  db = similar(b)
  dx = similar(x)
-  cudnnBNBackward!(dg, g, db, dx, x, dy, running_mean, running_var, T(momentum),
-                   training = training, cache = cache, eps = eps, alpha = alpha, beta = beta)
-  (dg, db, dx)
+  cudnnBNBackward!(dg, g, db, dx, x, dy, running_mean, running_var, T(momentum); kws...)
+  if affine
+    (dg, db, dx)
+  else
+    # cuDNN always calculates dg and db, therefore we just have to drop them
+    (nothing, nothing, dx)
+  end
end
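# Illustrative use of the new `affine` flag (hypothetical arguments): cuDNN
# still fills dg and db, but they are dropped so the caller sees no gradient
# for the absent scale/shift:
#
#   dg, db, dx = ∇batchnorm(g, b, x, dy, μ, σ², 0.1f0; affine = false)
#   dg === nothing && db === nothing   # true
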

function cudnnBNBackward!(dg::DenseCuArray{T}, g::DenseCuArray{T}, db::DenseCuArray{T},
                          dx::DenseCuArray{T}, x::DenseCuArray{T}, dy::DenseCuArray{T},
-                         running_mean::DenseCuArray{T}, running_var::DenseCuArray{T},
+                         running_mean, running_var,
                          momentum; cache = nothing, eps = T(1e-5),
                          alpha = T(1), beta = T(0),
-                         dalpha = T(1), dbeta = T(0), training = true) where T<:Union{Float32, Float64}
-  if training
-    xd = cudnnTensorDescriptor(x)
-    dyd = cudnnTensorDescriptor(dy)
-    dxd = cudnnTensorDescriptor(dx)
-    gd = cudnnTensorDescriptor(CUDNN_TENSOR_NCHW, cudnnDataType(T), Cint(length(_wsize(x))), dim4(_wsize(x), Val(CUDNN_TENSOR_NCHW)))
-    if cache !== nothing
-      mean, ivar = cache.mean, cache.ivar
-      info("mean and ivar are fetched from the cache")
-    else
-      mean, ivar = CU_NULL, CU_NULL
-    end
-
-    if eps < CUDNN_BN_MIN_EPSILON
-      eps = CUDNN_BN_MIN_EPSILON
-    end
+                         dalpha = T(1), dbeta = T(0), training = true,
+                         track_stats = true) where T<:CUDNNFloat
+  if !track_stats
+    running_mean = CU_NULL
+    running_var = CU_NULL
+  end

-    cudnnBatchNormalizationBackward(handle(), CUDNN_BATCHNORM_SPATIAL, scalingParameter(T, alpha), scalingParameter(T, beta), scalingParameter(T, dalpha), scalingParameter(T, dbeta), xd, x, dyd, dy, dxd, dx, gd, g, dg, db, eps, mean, ivar)
+  xd = cudnnTensorDescriptor(x)
+  dyd = cudnnTensorDescriptor(dy)
+  dxd = cudnnTensorDescriptor(dx)
+  gd = cudnnTensorDescriptor(CUDNN_TENSOR_NCHW, cudnnDataType(T), Cint(length(_wsize(x))), dim4(_wsize(x), Val(CUDNN_TENSOR_NCHW)))
+  if cache !== nothing
+    @debug "fetching mean and ivar from the cache"
+    mean, ivar = cache.mean, cache.ivar
  else
-    ivar = 1 ./ sqrt.(reshape(running_var, _wsize(x)) .+ eps)
-    dx .= dy .* reshape(g, _wsize(x)) .* ivar
-    rdims = ((1:ndims(x)-2)..., ndims(x))
-    dg .= vec(sum(dy .* (x .- reshape(running_mean, _wsize(x))) .* ivar, dims = rdims))
-    db .= vec(sum(dy, dims = rdims))
+    mean, ivar = CU_NULL, CU_NULL
+  end
+
+  if eps < CUDNN_BN_MIN_EPSILON
+    @warn "eps $eps is too small for CuDNN, setting to CUDNN_BN_MIN_EPSILON=$CUDNN_BN_MIN_EPSILON"
+    eps = CUDNN_BN_MIN_EPSILON
  end
+
+  cudnnBatchNormalizationBackward(handle(), CUDNN_BATCHNORM_SPATIAL,
+                                  scalingParameter(T, alpha), scalingParameter(T, beta), scalingParameter(T, dalpha), scalingParameter(T, dbeta),
+                                  xd, x, dyd, dy, dxd, dx, gd, g, dg, db, eps, mean, ivar)
end
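# End-to-end sketch under the new keywords (hypothetical arrays from the
# examples above): a forward pass that skips running-statistics updates,
# followed by the matching backward pass:
#
#   y = batchnorm(g, b, x, μ, σ², 0.1f0; training = true, track_stats = false)
#   dy = CUDA.ones(Float32, size(y))
#   dg, db, dx = ∇batchnorm(g, b, x, dy, μ, σ², 0.1f0; track_stats = false)
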
-