Skip to content

Commit d811f89

Browse files
authored
Do not free memory you do not own. (#503)
1 parent 312d685 commit d811f89

File tree

1 file changed

+3
-5
lines changed

1 file changed

+3
-5
lines changed

src/fast_layers.jl

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ ZygoteRules.@adjoint function (f::FastDense)(x,p)
7878

7979
y = f.σ.(r)
8080

81-
if typeof(f.σ) <: typeof(tanh)
81+
if typeof(f.σ) <: typeof(tanh) || typeof(f.σ) <: typeof(identity)
8282
ifgpufree(r)
8383
end
8484

@@ -87,25 +87,23 @@ ZygoteRules.@adjoint function (f::FastDense)(x,p)
8787
zbar =.* (1 .- y.^2)
8888
elseif typeof(f.σ) <: typeof(identity)
8989
zbar =
90-
ifgpufree(r)
9190
else
9291
zbar =.* ForwardDiff.derivative.(f.σ,r)
93-
ifgpufree(r)
9492
end
95-
ifgpufree(y)
9693
Wbar = zbar * x'
9794
bbar = zbar
9895
xbar = W' * zbar
9996
pbar = typeof(bbar) <: AbstractVector ?
10097
vec(vcat(vec(Wbar),bbar)) :
10198
vec(vcat(vec(Wbar),sum(bbar,dims=2)))
102-
ifgpufree(Wbar); ifgpufree(bbar); ifgpufree(ȳ)
99+
ifgpufree(Wbar); ifgpufree(bbar)
103100
nothing,xbar,pbar
104101
end
105102
y,FastDense_adjoint
106103
end
107104
paramlength(f::FastDense) = f.out*(f.in+f.bias)
108105
initial_params(f::FastDense) = f.initial_params()
106+
109107
"""
110108
StaticDense(in,out,activation=identity;
111109
initW = Flux.glorot_uniform, initb = Flux.zeros)

0 commit comments

Comments
 (0)