@@ -44,7 +44,7 @@ function cudnnBNForward!(y::DenseCuArray{T}, g::DenseCuArray{T}, b::DenseCuArray
                          eps = T(1e-5),
                          training = true,
                          affine = true,
-                         track_stats = true) where T<:Union{Float32, Float64}
+                         track_stats = true) where T<:CUDNNFloat
   dims = _wsize(x)
   if eps < CUDNN_BN_MIN_EPSILON
     @warn "eps $eps is too small for CuDNN, setting to CUDNN_BN_MIN_EPSILON=$CUDNN_BN_MIN_EPSILON"
@@ -99,7 +99,7 @@
 
 function ∇batchnorm(g::DenseCuArray{T}, b::DenseCuArray{T}, x::DenseCuArray{T, 2}, dy::DenseCuArray{T, 2},
                     running_mean, running_var, momentum;
-                    kws...) where T<:Union{Float32, Float64}
+                    kws...) where T<:CUDNNFloat
   dg, db, dx = ∇batchnorm(g, b, reshape(x, 1, 1, size(x, 1), size(x, 2)), reshape(dy, 1, 1, size(dy, 1),
                           size(dy, 2)), running_mean, running_var, momentum; kws...)
   (dg, db, dropdims(dx, dims = (1, 2)))
@@ -108,7 +108,7 @@
 
 function ∇batchnorm(g::DenseCuArray{T}, b::DenseCuArray{T}, x::DenseCuArray{T}, dy::DenseCuArray{T},
                     running_mean, running_var, momentum;
-                    affine=true, kws...) where T<:Union{Float32, Float64}
+                    affine=true, kws...) where T<:CUDNNFloat
   dg = similar(g)
   db = similar(b)
   dx = similar(x)
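For context (not part of the diff): the 2D ∇batchnorm method above handles `(features, batch)` inputs by reshaping them into the 4D layout the cuDNN batch-normalization calls expect, dispatching to the 4D method, and then dropping the two singleton dimensions from the returned dx. A small CPU-only sketch of that reshape round trip, using plain Arrays as stand-ins for the CuArrays:

# Plain-Array illustration of the 2D -> 4D -> 2D round trip used above.
C, N = 3, 5                                   # feature channels, batch size
x2d  = rand(Float32, C, N)                    # what the 2D method receives
x4d  = reshape(x2d, 1, 1, size(x2d, 1), size(x2d, 2))
@assert size(x4d) == (1, 1, C, N)             # trivial spatial dims, C channels, N batch
@assert dropdims(x4d, dims = (1, 2)) == x2d   # how dx is mapped back to 2D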
@@ -127,7 +127,7 @@ function cudnnBNBackward!(dg::DenseCuArray{T}, g::DenseCuArray{T}, db::DenseCuAr
                           momentum; cache = nothing, eps = T(1e-5),
                           alpha = T(1), beta = T(0),
                           dalpha = T(1), dbeta = T(0), training = true,
-                          track_stats = true) where T<:Union{Float32, Float64}
+                          track_stats = true) where T<:CUDNNFloat
   if !track_stats
     running_mean = CU_NULL
     running_var = CU_NULL
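To illustrate what the widened bound buys (again outside the diff itself): relaxing the `where` constraint from `Union{Float32, Float64}` to a broader alias makes the same methods applicable to Float16 arrays. A self-contained, GPU-free sketch of the dispatch behaviour; `MyCUDNNFloat`, `old_bound`, and `new_bound` are hypothetical toys, not the real API:

# Dispatch-only illustration; MyCUDNNFloat stands in for the package's CUDNNFloat alias.
const MyCUDNNFloat = Union{Float16, Float32, Float64}

old_bound(x::AbstractVector{T}) where T<:Union{Float32, Float64} = sum(x)
new_bound(x::AbstractVector{T}) where T<:MyCUDNNFloat            = sum(x)

x16 = rand(Float16, 4)
hasmethod(old_bound, Tuple{Vector{Float16}})   # false: Float16 is outside Union{Float32, Float64}
hasmethod(new_bound, Tuple{Vector{Float16}})   # true:  Float16 <: MyCUDNNFloat
new_bound(x16)                                 # dispatches and runs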