Skip to content

Commit 9c32265

Browse files
authored
track_stats in the right place
1 parent 8a12129 commit 9c32265

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

ext/NNlibCUDA/src/cudnn/batchnorm.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,12 +59,12 @@ function cudnnBNForward!(y::DenseCuArray{T}, g::DenseCuArray{T}, b::DenseCuArray
5959
yd = cudnnTensorDescriptor(y)
6060
gd = cudnnTensorDescriptor(CUDNN_TENSOR_NCHW, cudnnDataType(T), Cint(length(dims)), dim4(dims,Val(CUDNN_TENSOR_NCHW)))
6161

62-
if !track_stats
63-
running_mean = CU_NULL
64-
running_var = CU_NULL
65-
end
6662

6763
if training
64+
if !track_stats
65+
running_mean = CU_NULL
66+
running_var = CU_NULL
67+
end
6868
if cache !== nothing
6969
mean = zeros(CuArray{T}, dims...)
7070
ivar = ones(CuArray{T}, dims...)

0 commit comments

Comments
 (0)