Skip to content

Commit 7a0dbe8

Browse files
authored
Missed one
1 parent 65efe8e commit 7a0dbe8

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

ext/NNlibCUDA/src/cudnn/batchnorm.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ end
2626
# NOTE: CuDNN supports only 4D and 5D Tensors for BatchNorm Operations
2727
# so reshape a 2D Tensor into 4D
2828
function batchnorm(g::DenseCuArray{T}, b::DenseCuArray{T}, x::DenseCuArray{T,2},
29-
running_mean, running_var, momentum; kws...) where T<:Union{Float32, Float64}
29+
running_mean, running_var, momentum; kws...) where T<:CUDNNFloat
3030
x = reshape(x, 1, 1, size(x, 1), size(x, 2))
3131
y = batchnorm(g, b, x, running_mean, running_var, momentum; kws...)
3232
return dropdims(y, dims = (1, 2))

0 commit comments

Comments
 (0)