Skip to content

Commit 05331f1

Browse files
Use input_size correctly for v0.8
1 parent 6ac54c0 commit 05331f1

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

ext/NNlibCUDA/src/cudnn/pooling.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,9 @@ end
4141

4242
add1d(x) = reshape(x, 1, size(x)...)
4343

44-
fix_pooldims_1d(pdims::PoolDims{1,K,S,P,D}) where {K,S,P,D} =
45-
PoolDims{2,(1,K...),(1,S...),(0,0,P...),(1,D...)}((1,pdims.I...), pdims.C_in)
44+
function fix_pooldims_1d(pdims::PoolDims{1,K,S,P,D}) where {K,S,P,D}
45+
PoolDims{2,(1,K...),(1,S...),(0,0,P...),(1,D...)}((1,NNlib.input_size(pdims)...), pdims.C_in)
46+
end
4647

4748
function maxpool!(y::DenseCuArray{T,3}, x::DenseCuArray{T,3}, pdims::PoolDims) where T<:CUDNNFloat
4849
maxpool!(add1d(y), add1d(x), fix_pooldims_1d(pdims))

0 commit comments

Comments
 (0)