Skip to content

Commit 8a12129

Browse files
authored
Try to get tests passing
1 parent 397def0 commit 8a12129

File tree

2 files changed

+16
-20
lines changed

2 files changed

+16
-20
lines changed

ext/NNlibCUDA/src/cudnn/batchnorm.jl

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,14 @@ end
1313

1414
BNCache() = BNCache(nothing, nothing)
1515

16-
@inline _wsize(y) = (fill(1, ndims(y)-2)..., size(y)[end-1], 1)
16+
@inline _wsize(y) = ntuple(i -> i == ndims(y) - 1 ? 1 : size(y, i), ndims(y))
1717

1818
function batchnorm(g::Nothing, b::Nothing, x::DenseCuArray,
1919
running_mean, running_var, momentum;
2020
kws...)
21-
g = fill!(similar(x, size(ndims(x)-1)), 1)
22-
b = fill!(similar(x, size(ndims(x)-1)), 0)
21+
affine_sz = _wsize(x)
22+
g = fill!(similar(x, affine_sz), 1)
23+
b = fill!(similar(x, affine_sz), 0)
2324

2425
batchnorm(g, b, x, running_mean, running_var, momentum;
2526
kws...)
@@ -86,8 +87,9 @@ end
8687

8788
function ∇batchnorm(g::Nothing, b::Nothing, x::DenseCuArray, dy::DenseCuArray,
8889
running_mean, running_var, momentum; kws...)
89-
g = fill!(similar(x, size(ndims(x)-1)), 1)
90-
b = fill!(similar(x, size(ndims(x)-1)), 0)
90+
affine_sz = _wsize(x)
91+
g = fill!(similar(x, affine_sz), 1)
92+
b = fill!(similar(x, affine_sz), 0)
9193
∇batchnorm(g, b, x, dy, running_mean, running_var, momentum; kws...)
9294
end
9395

@@ -134,7 +136,7 @@ function cudnnBNBackward!(dg::DenseCuArray{T}, g::DenseCuArray{T}, db::DenseCuAr
134136
gd = cudnnTensorDescriptor(CUDNN_TENSOR_NCHW, cudnnDataType(T), Cint(length(_wsize(x))), dim4(_wsize(x),Val(CUDNN_TENSOR_NCHW)))
135137
if cache !== nothing
136138
mean, ivar = cache.mean, cache.ivar
137-
# info("mean and ivar are fetched from the cache")
139+
@debug "mean and ivar are fetched from the cache"
138140
else
139141
mean, ivar = CU_NULL, CU_NULL
140142
end
@@ -143,23 +145,17 @@ function cudnnBNBackward!(dg::DenseCuArray{T}, g::DenseCuArray{T}, db::DenseCuAr
143145
eps = CUDNN_BN_MIN_EPSILON
144146
end
145147

146-
if training
147-
cudnnBatchNormalizationBackward(handle(), CUDNN_BATCHNORM_SPATIAL,
148-
scalingParameter(T, alpha), scalingParameter(T, beta), scalingParameter(T, dalpha), scalingParameter(T, dbeta),
149-
xd, x, dyd, dy, dxd, dx, gd, g, dg, db, eps,
150-
mean, ivar)
151-
else
152-
cudnnBatchNormalizationBackward(handle(), CUDNN_BATCHNORM_SPATIAL,
153-
scalingParameter(T, alpha), scalingParameter(T, beta), scalingParameter(T, dalpha), scalingParameter(T, dbeta),
154-
xd, x, dyd, dy, dxd, dx, gd, g, dg, db, eps,
155-
running_mean, running_var)
156-
end
148+
cudnnBatchNormalizationBackward(handle(), CUDNN_BATCHNORM_SPATIAL,
149+
scalingParameter(T, alpha), scalingParameter(T, beta), scalingParameter(T, dalpha), scalingParameter(T, dbeta),
150+
xd, x, dyd, dy, dxd, dx, gd, g, dg, db, eps,
151+
mean, ivar)
157152
end
158153

159154
function rrule(::typeof(batchnorm), g, b, x, running_mean, running_var, momentum; kws...)
160155
y = batchnorm(g, b, x, running_mean, running_var, momentum; kws...)
161156
function batchnorm_pullback(Δ)
162-
NoTangent(), ∇batchnorm(g, b, x, Δ, running_mean, running_var, momentum; kws...)..., NoTangent(), NoTangent(), NoTangent()
157+
dg, db, dx = ∇batchnorm(g, b, x, Δ, running_mean, running_var, momentum; kws...)
158+
NoTangent(), something(dg, NoTangent()), something(db, NoTangent()), dx, NoTangent(), NoTangent(), NoTangent()
163159
end
164160
y, batchnorm_pullback
165161
end

ext/NNlibCUDA/test/batchnorm.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88

99
@testset "track_stats=false" begin
1010
for training in (false, true)
11-
NNlibCUDA.batchnorm(v, v, m, nothing, nothing, 1.0; training=training)
12-
NNlibCUDA.∇batchnorm(v, v, m, m, nothing, nothing, 1.0; training=training)
11+
NNlibCUDA.batchnorm(v, v, m, nothing, nothing, 1.0; training=training, track_stats=false)
12+
NNlibCUDA.∇batchnorm(v, v, m, m, nothing, nothing, 1.0; training=training, track_stats=false)
1313
end
1414
end
1515
end

0 commit comments

Comments
 (0)