Skip to content

Commit b16c745

Browse files
authored
one more
1 parent 7a0dbe8 commit b16c745

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

ext/NNlibCUDA/src/cudnn/batchnorm.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ function batchnorm(g::DenseCuArray{T}, b::DenseCuArray{T}, x::DenseCuArray{T,2},
3333
end
3434

3535
function batchnorm(g::DenseCuArray{T}, b::DenseCuArray{T}, x::Union{DenseCuArray{T,4},DenseCuArray{T,5}},
36-
running_mean, running_var, momentum; kws...) where T<:Union{Float32, Float64}
36+
running_mean, running_var, momentum; kws...) where T<:CUDNNFloat
3737
cudnnBNForward!(similar(x), g, b, x, running_mean, running_var, momentum; kws...)
3838
end
3939

0 commit comments

Comments
 (0)