
Commit 397def0

add affine option in batchnorm

1 parent 4221f71

3 files changed, +59 -22 lines

ext/NNlibCUDA/Project.toml

Lines changed: 2 additions & 0 deletions
@@ -4,12 +4,14 @@ version = "0.1.11"
 
 [deps]
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
+ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 
 [compat]
+ChainRulesCore = "1"
 CUDA = "3.3.1"
 NNlib = "0.7.31"
 julia = "1.6"

ext/NNlibCUDA/src/NNlibCUDA.jl

Lines changed: 2 additions & 0 deletions
@@ -3,6 +3,8 @@ module NNlibCUDA
 using NNlib
 using CUDA
 using Random, Statistics
+using ChainRulesCore: NoTangent, ZeroTangent
+import ChainRulesCore: rrule
 
 const IntOrIntTuple = Union{Integer, NTuple{N,<:Integer} where N}

ext/NNlibCUDA/src/cudnn/batchnorm.jl

Lines changed: 55 additions & 22 deletions
@@ -15,6 +15,16 @@ BNCache() = BNCache(nothing, nothing)
 
 @inline _wsize(y) = (fill(1, ndims(y)-2)..., size(y)[end-1], 1)
 
+function batchnorm(g::Nothing, b::Nothing, x::DenseCuArray,
+                   running_mean, running_var, momentum;
+                   kws...)
+  g = fill!(similar(x, size(x, ndims(x)-1)), 1)
+  b = fill!(similar(x, size(x, ndims(x)-1)), 0)
+
+  batchnorm(g, b, x, running_mean, running_var, momentum;
+            kws...)
+end
+
 # NOTE: CuDNN supports only 4D and 5D Tensors for BatchNorm Operations
 # so reshape a 2D Tensor into 4D
 batchnorm(g::DenseCuArray{T}, b::DenseCuArray{T}, x::DenseCuArray{T,2},
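
For context, a minimal usage sketch of the new g::Nothing, b::Nothing method. The array shapes, the 0.1f0 momentum, and the NNlibCUDA qualification are illustrative assumptions, not part of the commit:

    using CUDA, NNlibCUDA

    x  = CUDA.rand(Float32, 8, 8, 4, 2)   # H × W × C × N, i.e. 4 channels
    rm = CUDA.zeros(Float32, 4)           # running mean, one entry per channel
    rv = CUDA.ones(Float32, 4)            # running variance, one entry per channel

    # Passing `nothing` for scale and bias dispatches to the new method above,
    # which fills in a unit scale and zero bias before calling into cuDNN.
    y = NNlibCUDA.batchnorm(nothing, nothing, x, rm, rv, 0.1f0; training = true)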
@@ -37,6 +47,7 @@ function cudnnBNForward!(y::DenseCuArray{T}, g::DenseCuArray{T}, b::DenseCuArray
                          alpha = T(1), beta = T(0),
                          eps = T(1e-5),
                          training = true,
+                         affine = true,
                          track_stats = true) where T<:Union{Float32, Float64}
   dims = _wsize(x)
   if eps < CUDNN_BN_MIN_EPSILON
@@ -73,6 +84,13 @@ function cudnnBNForward!(y::DenseCuArray{T}, g::DenseCuArray{T}, b::DenseCuArray
   return y
 end
 
+function ∇batchnorm(g::Nothing, b::Nothing, x::DenseCuArray, dy::DenseCuArray,
+                    running_mean, running_var, momentum; kws...)
+  g = fill!(similar(x, size(x, ndims(x)-1)), 1)
+  b = fill!(similar(x, size(x, ndims(x)-1)), 0)
+  ∇batchnorm(g, b, x, dy, running_mean, running_var, momentum; kws...)
+end
+
 function ∇batchnorm(g::DenseCuArray{T}, b::DenseCuArray{T}, x::DenseCuArray{T, 2}, dy::DenseCuArray{T, 2},
                     running_mean, running_var, momentum;
                     kws...) where T<:Union{Float32, Float64}
@@ -81,14 +99,20 @@ function ∇batchnorm(g::DenseCuArray{T}, b::DenseCuArray{T}, x::DenseCuArray{T,
   (dg, db, dropdims(dx, dims = (1, 2)))
 end
 
+
 function ∇batchnorm(g::DenseCuArray{T}, b::DenseCuArray{T}, x::DenseCuArray{T}, dy::DenseCuArray{T},
                     running_mean, running_var, momentum;
-                    kws...) where T<:Union{Float32, Float64}
+                    affine=true, kws...) where T<:Union{Float32, Float64}
   dg = similar(g)
   db = similar(b)
   dx = similar(x)
   cudnnBNBackward!(dg, g, db, dx, x, dy, running_mean, running_var, T(momentum); kws...)
-  (dg, db, dx)
+  if affine
+    (dg, db, dx)
+  else
+    # CUDNN always calculates dg and db, therefore we just have to drop them
+    (nothing, nothing, dx)
+  end
 end
 
 function cudnnBNBackward!(dg::DenseCuArray{T}, g::DenseCuArray{T}, db::DenseCuArray{T},
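
A minimal sketch of what the new affine flag changes in the backward pass (self-contained setup; shapes and momentum are again assumptions):

    using CUDA, NNlibCUDA

    x  = CUDA.rand(Float32, 8, 8, 4, 2)
    dy = CUDA.rand(Float32, 8, 8, 4, 2)
    g  = CUDA.ones(Float32, 4)
    b  = CUDA.zeros(Float32, 4)
    rm = CUDA.zeros(Float32, 4)
    rv = CUDA.ones(Float32, 4)

    # With affine = false, cuDNN still computes dg and db internally,
    # but the wrapper discards them and only dx is returned.
    dg, db, dx = NNlibCUDA.∇batchnorm(g, b, x, dy, rm, rv, 0.1f0; affine = false)
    @assert dg === nothing && db === nothing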
@@ -104,29 +128,38 @@ function cudnnBNBackward!(dg::DenseCuArray{T}, g::DenseCuArray{T}, db::DenseCuAr
     running_var = CU_NULL
   end
 
-  if training
-    xd = cudnnTensorDescriptor(x)
-    dyd = cudnnTensorDescriptor(dy)
-    dxd = cudnnTensorDescriptor(dx)
-    gd = cudnnTensorDescriptor(CUDNN_TENSOR_NCHW, cudnnDataType(T), Cint(length(_wsize(x))), dim4(_wsize(x),Val(CUDNN_TENSOR_NCHW)))
-    if cache !== nothing
-      mean, ivar = cache.mean, cache.ivar
-      info("mean and ivar are fetched from the cache")
-    else
-      mean, ivar = CU_NULL, CU_NULL
-    end
+  xd = cudnnTensorDescriptor(x)
+  dyd = cudnnTensorDescriptor(dy)
+  dxd = cudnnTensorDescriptor(dx)
+  gd = cudnnTensorDescriptor(CUDNN_TENSOR_NCHW, cudnnDataType(T), Cint(length(_wsize(x))), dim4(_wsize(x),Val(CUDNN_TENSOR_NCHW)))
+  if cache !== nothing
+    mean, ivar = cache.mean, cache.ivar
+    # info("mean and ivar are fetched from the cache")
+  else
+    mean, ivar = CU_NULL, CU_NULL
+  end
 
-    if eps < CUDNN_BN_MIN_EPSILON
-      eps = CUDNN_BN_MIN_EPSILON
-    end
+  if eps < CUDNN_BN_MIN_EPSILON
+    eps = CUDNN_BN_MIN_EPSILON
+  end
 
-    cudnnBatchNormalizationBackward(handle(), CUDNN_BATCHNORM_SPATIAL, scalingParameter(T, alpha), scalingParameter(T, beta), scalingParameter(T, dalpha), scalingParameter(T, dbeta), xd, x, dyd, dy, dxd, dx, gd, g, dg, db, eps, mean, ivar)
+  if training
+    cudnnBatchNormalizationBackward(handle(), CUDNN_BATCHNORM_SPATIAL,
+                                    scalingParameter(T, alpha), scalingParameter(T, beta), scalingParameter(T, dalpha), scalingParameter(T, dbeta),
+                                    xd, x, dyd, dy, dxd, dx, gd, g, dg, db, eps,
+                                    mean, ivar)
   else
-    ivar = 1 ./ sqrt.(reshape(running_var, _wsize(x)) .+ eps)
-    dx .= dy .* reshape(g, _wsize(x)) .* ivar
-    rdims = ((1:ndims(x)-2)..., ndims(x))
-    dg .= vec(sum(dy .* (x .- reshape(running_mean, _wsize(x))) .* ivar, dims=rdims))
-    db .= vec(sum(dy, dims=rdims))
+    cudnnBatchNormalizationBackward(handle(), CUDNN_BATCHNORM_SPATIAL,
+                                    scalingParameter(T, alpha), scalingParameter(T, beta), scalingParameter(T, dalpha), scalingParameter(T, dbeta),
+                                    xd, x, dyd, dy, dxd, dx, gd, g, dg, db, eps,
+                                    running_mean, running_var)
   end
 end
 
+function rrule(::typeof(batchnorm), g, b, x, running_mean, running_var, momentum; kws...)
+  y = batchnorm(g, b, x, running_mean, running_var, momentum; kws...)
+  function batchnorm_pullback(Δ)
+    NoTangent(), ∇batchnorm(g, b, x, Δ, running_mean, running_var, momentum; kws...)..., NoTangent(), NoTangent(), NoTangent()
+  end
+  y, batchnorm_pullback
+end
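
Finally, a sketch of exercising the new rrule directly through ChainRulesCore (shapes, momentum, and the module qualification are assumptions):

    using CUDA, NNlibCUDA
    using ChainRulesCore: rrule

    x  = CUDA.rand(Float32, 8, 8, 4, 2)
    g  = CUDA.ones(Float32, 4)
    b  = CUDA.zeros(Float32, 4)
    rm = CUDA.zeros(Float32, 4)
    rv = CUDA.ones(Float32, 4)

    y, back = rrule(NNlibCUDA.batchnorm, g, b, x, rm, rv, 0.1f0)
    Δ = one.(y)   # seed cotangent, same shape as y
    # The pullback mirrors the seven positional arguments:
    # (batchnorm itself, g, b, x, running_mean, running_var, momentum).
    _, dg, db, dx, _, _, _ = back(Δ)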
