13
13
14
14
"""
    BNCache()

Zero-argument convenience constructor: forwards `nothing` for both positional
arguments of the two-argument `BNCache` constructor.
"""
function BNCache()
    return BNCache(nothing, nothing)
end
15
15
16
- @inline _wsize (y) = ( fill ( 1 , ndims (y)- 2 ) ... , size (y)[ end - 1 ], 1 )
16
"""
    _wsize(y)

Return the size of the scale/bias arrays that accompany `y`: a tuple of length
`ndims(y)` holding `size(y, ndims(y) - 1)` in the second-to-last position and
`1` in every other position, e.g. `(1, 1, C, 1)` for a `(W, H, C, N)` array.

# Examples
```julia
julia> _wsize(zeros(2, 3, 4, 5))
(1, 1, 4, 1)
```
"""
# Only the second-to-last dimension keeps its extent; every other dimension is
# a singleton. Swapping the two ternary branches would instead produce the
# shape of `y` with that one dimension collapsed, which is not a valid
# scale/bias shape.
@inline _wsize(y) = ntuple(i -> i == ndims(y) - 1 ? size(y, i) : 1, ndims(y))
17
17
18
18
function batchnorm (g:: Nothing , b:: Nothing , x:: DenseCuArray ,
19
19
running_mean, running_var, momentum;
20
20
kws... )
21
- g = fill! (similar (x, size (ndims (x)- 1 )), 1 )
22
- b = fill! (similar (x, size (ndims (x)- 1 )), 0 )
21
+ affine_sz = _wsize (x)
22
+ g = fill! (similar (x, affine_sz), 1 )
23
+ b = fill! (similar (x, affine_sz), 0 )
23
24
24
25
batchnorm (g, b, x, running_mean, running_var, momentum;
25
26
kws... )
86
87
87
88
"""
    ∇batchnorm(g::Nothing, b::Nothing, x, dy, running_mean, running_var, momentum; kws...)

Method used when neither a scale `g` nor a bias `b` is supplied. It allocates a
scale filled with ones and a bias filled with zeros, both `similar` to `x` with
size `_wsize(x)`, and forwards them with all remaining arguments and keywords
to the `∇batchnorm` method that takes explicit `g` and `b`.
"""
function ∇batchnorm(g::Nothing, b::Nothing, x::DenseCuArray, dy::DenseCuArray,
                    running_mean, running_var, momentum; kws...)
    param_dims = _wsize(x)

    # Identity affine parameters: scale of ones, shift of zeros.
    scale = similar(x, param_dims)
    fill!(scale, 1)
    shift = similar(x, param_dims)
    fill!(shift, 0)

    return ∇batchnorm(scale, shift, x, dy, running_mean, running_var, momentum; kws...)
end
93
95
@@ -134,7 +136,7 @@ function cudnnBNBackward!(dg::DenseCuArray{T}, g::DenseCuArray{T}, db::DenseCuAr
134
136
gd = cudnnTensorDescriptor (CUDNN_TENSOR_NCHW, cudnnDataType (T), Cint (length (_wsize (x))), dim4 (_wsize (x),Val (CUDNN_TENSOR_NCHW)))
135
137
if cache != = nothing
136
138
mean, ivar = cache. mean, cache. ivar
137
- # info( "mean and ivar are fetched from the cache")
139
+ @debug " mean and ivar are fetched from the cache"
138
140
else
139
141
mean, ivar = CU_NULL, CU_NULL
140
142
end
@@ -143,23 +145,17 @@ function cudnnBNBackward!(dg::DenseCuArray{T}, g::DenseCuArray{T}, db::DenseCuAr
143
145
eps = CUDNN_BN_MIN_EPSILON
144
146
end
145
147
146
- if training
147
- cudnnBatchNormalizationBackward (handle (), CUDNN_BATCHNORM_SPATIAL,
148
- scalingParameter (T, alpha), scalingParameter (T, beta), scalingParameter (T, dalpha), scalingParameter (T, dbeta),
149
- xd, x, dyd, dy, dxd, dx, gd, g, dg, db, eps,
150
- mean, ivar)
151
- else
152
- cudnnBatchNormalizationBackward (handle (), CUDNN_BATCHNORM_SPATIAL,
153
- scalingParameter (T, alpha), scalingParameter (T, beta), scalingParameter (T, dalpha), scalingParameter (T, dbeta),
154
- xd, x, dyd, dy, dxd, dx, gd, g, dg, db, eps,
155
- running_mean, running_var)
156
- end
148
+ cudnnBatchNormalizationBackward (handle (), CUDNN_BATCHNORM_SPATIAL,
149
+ scalingParameter (T, alpha), scalingParameter (T, beta), scalingParameter (T, dalpha), scalingParameter (T, dbeta),
150
+ xd, x, dyd, dy, dxd, dx, gd, g, dg, db, eps,
151
+ mean, ivar)
157
152
end
158
153
159
154
"""
    rrule(::typeof(batchnorm), g, b, x, running_mean, running_var, momentum; kws...)

Reverse-mode rule for `batchnorm`. Returns the forward result `y` together with
a pullback. The pullback calls `∇batchnorm` with the incoming cotangent and
returns a 7-tuple: `NoTangent()` for the function itself, the gradients for `g`,
`b` and `x`, and `NoTangent()` for `running_mean`, `running_var` and `momentum`.
A `g` or `b` gradient that comes back as `nothing` is reported as `NoTangent()`.
"""
function rrule(::typeof(batchnorm), g, b, x, running_mean, running_var, momentum; kws...)
    y = batchnorm(g, b, x, running_mean, running_var, momentum; kws...)

    function batchnorm_pullback(Δ)
        grads = ∇batchnorm(g, b, x, Δ, running_mean, running_var, momentum; kws...)
        dg, db, dx = grads
        # `nothing` gradients for the affine parameters become `NoTangent()`.
        ∂g = something(dg, NoTangent())
        ∂b = something(db, NoTangent())
        return (NoTangent(), ∂g, ∂b, dx, NoTangent(), NoTangent(), NoTangent())
    end

    return y, batchnorm_pullback
end
0 commit comments