Skip to content

Commit 998714c

Browse files
fixes
1 parent 85947af commit 998714c

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

ext/NNlibCUDA/src/cudnn/pooling.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,11 @@ add1d(x) = reshape(x, 1, size(x)...)
4343

4444
function fix_pooldims_1d(pdims::PoolDims{1,K,S,P,D}) where {K,S,P,D}
4545
PoolDims{2, K + 1, S + 1, P + 2, D + 1}((1, NNlib.input_size(pdims)...),
46-
(1, K...),
46+
(1, NNlib.kernel_size(pdims)...),
4747
NNlib.channels_in(pdims),
48-
(1, S...),
49-
(0, 0, P...),
50-
(1, D...))
48+
(1, NNlib.stride(pdims)...),
49+
(0, 0, NNlib.padding(pdims)...),
50+
(1, NNlib.dilation(pdims)...))
5151
end
5252

5353
function maxpool!(y::DenseCuArray{T,3}, x::DenseCuArray{T,3}, pdims::PoolDims) where T<:CUDNNFloat

0 commit comments

Comments
 (0)