Skip to content

Commit 32919ac

Browse files
Merge pull request #40 from FluxML/DhairyaLGandhi-patch-1
Use new `PoolDims` constructor for NNlib v0.8
2 parents 6ac54c0 + 998714c commit 32919ac

File tree

2 files changed

+10
-4
lines changed

2 files changed

+10
-4
lines changed

ext/NNlibCUDA/Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "NNlibCUDA"
22
uuid = "a00861dc-f156-4864-bf3c-e6376f28a68d"
3-
version = "0.1.11"
3+
version = "0.2.0"
44

55
[deps]
66
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
@@ -11,7 +11,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1111

1212
[compat]
1313
CUDA = "3.3.1"
14-
NNlib = "0.7.31"
14+
NNlib = "0.8"
1515
julia = "1.6"
1616

1717
[extras]

ext/NNlibCUDA/src/cudnn/pooling.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,14 @@ end
4141

4242
add1d(x) = reshape(x, 1, size(x)...)
4343

44-
fix_pooldims_1d(pdims::PoolDims{1,K,S,P,D}) where {K,S,P,D} =
45-
PoolDims{2,(1,K...),(1,S...),(0,0,P...),(1,D...)}((1,pdims.I...), pdims.C_in)
44+
function fix_pooldims_1d(pdims::PoolDims{1,K,S,P,D}) where {K,S,P,D}
45+
PoolDims{2, K + 1, S + 1, P + 2, D + 1}((1, NNlib.input_size(pdims)...),
46+
(1, NNlib.kernel_size(pdims)...),
47+
NNlib.channels_in(pdims),
48+
(1, NNlib.stride(pdims)...),
49+
(0, 0, NNlib.padding(pdims)...),
50+
(1, NNlib.dilation(pdims)...))
51+
end
4652

4753
function maxpool!(y::DenseCuArray{T,3}, x::DenseCuArray{T,3}, pdims::PoolDims) where T<:CUDNNFloat
4854
maxpool!(add1d(y), add1d(x), fix_pooldims_1d(pdims))

0 commit comments

Comments
 (0)