@@ -15,6 +15,16 @@ BNCache() = BNCache(nothing, nothing)
 
 @inline _wsize(y) = (fill(1, ndims(y)-2)..., size(y)[end-1], 1)
 
+function batchnorm(g::Nothing, b::Nothing, x::DenseCuArray,
+                   running_mean, running_var, momentum;
+                   kws...)
+  g = fill!(similar(x, size(x, ndims(x)-1)), 1)
+  b = fill!(similar(x, size(x, ndims(x)-1)), 0)
+
+  batchnorm(g, b, x, running_mean, running_var, momentum;
+            kws...)
+end
+
 # NOTE: CuDNN supports only 4D and 5D Tensors for BatchNorm Operations
 # so reshape a 2D Tensor into 4D
 batchnorm(g::DenseCuArray{T}, b::DenseCuArray{T}, x::DenseCuArray{T,2},
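# ---- Editor's illustrative sketch (not part of the diff). Assuming CUDA.jl is
# loaded and `_wsize` is in scope, this shows what the new `Nothing` fallback
# materializes for a WHCN input: `_wsize` keeps only the channel dimension, and
# the fallback builds an identity affine transform (ones scale, zeros bias).
using CUDA
x = CUDA.rand(Float32, 8, 8, 3, 16)
@assert _wsize(x) == (1, 1, 3, 1)
g = fill!(similar(x, size(x, ndims(x)-1)), 1)  # ones over the 3 channels
b = fill!(similar(x, size(x, ndims(x)-1)), 0)  # zeros over the 3 channels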
@@ -37,6 +47,7 @@ function cudnnBNForward!(y::DenseCuArray{T}, g::DenseCuArray{T}, b::DenseCuArray
                          alpha = T(1), beta = T(0),
                          eps = T(1e-5),
                          training = true,
+                         affine = true,
                          track_stats = true) where T<:Union{Float32, Float64}
   dims = _wsize(x)
   if eps < CUDNN_BN_MIN_EPSILON
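# ---- Editor's illustrative sketch (not part of the diff). cuDNN rejects batch
# norm epsilons below CUDNN_BN_MIN_EPSILON, so the forward pass clamps rather
# than erroring; the `if` above is equivalent to:
#   eps = max(eps, CUDNN_BN_MIN_EPSILON)
# A too-small user value such as T(1e-7) is therefore silently raised to the
# cuDNN minimum before being handed to the cuDNN call.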
@@ -73,6 +84,13 @@ function cudnnBNForward!(y::DenseCuArray{T}, g::DenseCuArray{T}, b::DenseCuArray
   return y
 end
 
+function ∇batchnorm(g::Nothing, b::Nothing, x::DenseCuArray, dy::DenseCuArray,
+                    running_mean, running_var, momentum; kws...)
+  g = fill!(similar(x, size(x, ndims(x)-1)), 1)
+  b = fill!(similar(x, size(x, ndims(x)-1)), 0)
+  ∇batchnorm(g, b, x, dy, running_mean, running_var, momentum; kws...)
+end
+
 function ∇batchnorm(g::DenseCuArray{T}, b::DenseCuArray{T}, x::DenseCuArray{T, 2}, dy::DenseCuArray{T, 2},
                     running_mean, running_var, momentum;
                     kws...) where T<:Union{Float32, Float64}
@@ -81,14 +99,20 @@ function ∇batchnorm(g::DenseCuArray{T}, b::DenseCuArray{T}, x::DenseCuArray{T,
   (dg, db, dropdims(dx, dims = (1, 2)))
 end
 
+
 function ∇batchnorm(g::DenseCuArray{T}, b::DenseCuArray{T}, x::DenseCuArray{T}, dy::DenseCuArray{T},
                     running_mean, running_var, momentum;
-                    kws...) where T<:Union{Float32, Float64}
+                    affine = true, kws...) where T<:Union{Float32, Float64}
   dg = similar(g)
   db = similar(b)
   dx = similar(x)
   cudnnBNBackward!(dg, g, db, dx, x, dy, running_mean, running_var, T(momentum); kws...)
-  (dg, db, dx)
+  if affine
+    (dg, db, dx)
+  else
+    # CUDNN always calculates dg and db, therefore we just have to drop them
+    (nothing, nothing, dx)
+  end
 end
 
 function cudnnBNBackward!(dg::DenseCuArray{T}, g::DenseCuArray{T}, db::DenseCuArray{T},
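# ---- Editor's illustrative sketch (not part of the diff). With the new
# `affine` keyword, a hypothetical caller without learnable scale/bias can drop
# the scale/bias gradients that cuDNN computes unconditionally:
#   dg, db, dx = ∇batchnorm(g, b, x, dy, running_mean, running_var, 0.1f0)
#   _, _, dx   = ∇batchnorm(g, b, x, dy, running_mean, running_var, 0.1f0;
#                           affine = false)  # returns (nothing, nothing, dx)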
@@ -104,29 +128,38 @@ function cudnnBNBackward!(dg::DenseCuArray{T}, g::DenseCuArray{T}, db::DenseCuAr
     running_var = CU_NULL
   end
 
-  if training
-    xd = cudnnTensorDescriptor(x)
-    dyd = cudnnTensorDescriptor(dy)
-    dxd = cudnnTensorDescriptor(dx)
-    gd = cudnnTensorDescriptor(CUDNN_TENSOR_NCHW, cudnnDataType(T), Cint(length(_wsize(x))), dim4(_wsize(x), Val(CUDNN_TENSOR_NCHW)))
-    if cache !== nothing
-      mean, ivar = cache.mean, cache.ivar
-      info("mean and ivar are fetched from the cache")
-    else
-      mean, ivar = CU_NULL, CU_NULL
-    end
+  xd = cudnnTensorDescriptor(x)
+  dyd = cudnnTensorDescriptor(dy)
+  dxd = cudnnTensorDescriptor(dx)
+  gd = cudnnTensorDescriptor(CUDNN_TENSOR_NCHW, cudnnDataType(T), Cint(length(_wsize(x))), dim4(_wsize(x), Val(CUDNN_TENSOR_NCHW)))
+  if cache !== nothing
+    mean, ivar = cache.mean, cache.ivar
+    # info("mean and ivar are fetched from the cache")
+  else
+    mean, ivar = CU_NULL, CU_NULL
+  end
 
-    if eps < CUDNN_BN_MIN_EPSILON
-      eps = CUDNN_BN_MIN_EPSILON
-    end
+  if eps < CUDNN_BN_MIN_EPSILON
+    eps = CUDNN_BN_MIN_EPSILON
+  end
 
-    cudnnBatchNormalizationBackward(handle(), CUDNN_BATCHNORM_SPATIAL, scalingParameter(T, alpha), scalingParameter(T, beta), scalingParameter(T, dalpha), scalingParameter(T, dbeta), xd, x, dyd, dy, dxd, dx, gd, g, dg, db, eps, mean, ivar)
+  if training
+    cudnnBatchNormalizationBackward(handle(), CUDNN_BATCHNORM_SPATIAL,
+                                    scalingParameter(T, alpha), scalingParameter(T, beta), scalingParameter(T, dalpha), scalingParameter(T, dbeta),
+                                    xd, x, dyd, dy, dxd, dx, gd, g, dg, db, eps,
+                                    mean, ivar)
   else
-    ivar = 1 ./ sqrt.(reshape(running_var, _wsize(x)) .+ eps)
-    dx .= dy .* reshape(g, _wsize(x)) .* ivar
-    rdims = ((1:ndims(x)-2)..., ndims(x))
-    dg .= vec(sum(dy .* (x .- reshape(running_mean, _wsize(x))) .* ivar, dims=rdims))
-    db .= vec(sum(dy, dims=rdims))
+    cudnnBatchNormalizationBackward(handle(), CUDNN_BATCHNORM_SPATIAL,
+                                    scalingParameter(T, alpha), scalingParameter(T, beta), scalingParameter(T, dalpha), scalingParameter(T, dbeta),
+                                    xd, x, dyd, dy, dxd, dx, gd, g, dg, db, eps,
+                                    running_mean, running_var)
   end
 end
 
+function rrule(::typeof(batchnorm), g, b, x, running_mean, running_var, momentum; kws...)
+  y = batchnorm(g, b, x, running_mean, running_var, momentum; kws...)
+  function batchnorm_pullback(Δ)
+    NoTangent(), ∇batchnorm(g, b, x, Δ, running_mean, running_var, momentum; kws...)..., NoTangent(), NoTangent(), NoTangent()
+  end
+  y, batchnorm_pullback
+end
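# ---- Editor's illustrative sketch (not part of the diff). Assuming `rrule` and
# `NoTangent` come from ChainRulesCore, the new rule wires the cuDNN backward
# pass into ChainRules-aware ADs such as Zygote:
#   y, back = rrule(batchnorm, g, b, x, running_mean, running_var, 0.1f0)
#   Δ = similar(y)                    # upstream sensitivity
#   _, dg, db, dx, _, _, _ = back(Δ)  # one tangent per `batchnorm` argument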