Skip to content

Commit d29ab6e

Browse files
authored
Merge pull request #53 from maxfreu/patch-1
save allocs during algorithm search
2 parents 24cd95d + a5e4d55 commit d29ab6e

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

ext/NNlibCUDA/Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1010
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1111

1212
[compat]
13-
CUDA = "3.3.1"
14-
NNlib = "0.8.6"
13+
CUDA = "3.11"
14+
NNlib = "0.8.7"
1515
julia = "1.6"
1616

1717
[extras]

ext/NNlibCUDA/src/cudnn/conv.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ function ∇conv_data!(dx::DenseCuArray{T}, dy::DenseCuArray{T}, w::DenseCuArray
9797
alpha, beta = scalingParameter(T,alpha), scalingParameter(T,beta);
9898
convDesc, dx, depad = cudnnConvolutionDescriptorAndPaddedInput(cdims, dx)
9999
xDesc, yDesc, wDesc = cudnnTensorDescriptor(dx), cudnnTensorDescriptor(dy), cudnnFilterDescriptor(w)
100-
p = cudnnConvolutionBwdDataAlgoPerf(wDesc, w, yDesc, dy, convDesc, xDesc, dx)
100+
p = cudnnConvolutionBwdDataAlgoPerf(wDesc, w, yDesc, dy, convDesc, xDesc, dx, beta!=0)
101101
with_workspace(p.memory) do workspace
102102
cudnnConvolutionBackwardData(handle(), alpha, wDesc, w, yDesc, dy, convDesc, p.algo, workspace, sizeof(workspace), beta, xDesc, dx)
103103
end
@@ -115,7 +115,7 @@ function ∇conv_filter!(dw::DenseCuArray{T}, x::DenseCuArray{T}, dy::DenseCuArr
115115
alpha, beta = scalingParameter(T,alpha), scalingParameter(T,beta);
116116
convDesc, x, _ = cudnnConvolutionDescriptorAndPaddedInput(cdims, x)
117117
xDesc, yDesc, wDesc = cudnnTensorDescriptor(x), cudnnTensorDescriptor(dy), cudnnFilterDescriptor(dw)
118-
p = cudnnConvolutionBwdFilterAlgoPerf(xDesc, x, yDesc, dy, convDesc, wDesc, dw);
118+
p = cudnnConvolutionBwdFilterAlgoPerf(xDesc, x, yDesc, dy, convDesc, wDesc, dw, beta!=0);
119119
with_workspace(p.memory) do workspace
120120
cudnnConvolutionBackwardFilter(handle(), alpha, xDesc, x, yDesc, dy, convDesc, p.algo, workspace, sizeof(workspace), beta, wDesc, dw);
121121
end

0 commit comments

Comments
 (0)