File tree Expand file tree Collapse file tree 1 file changed +3
-5
lines changed Expand file tree Collapse file tree 1 file changed +3
-5
lines changed Original file line number Diff line number Diff line change @@ -78,7 +78,7 @@ ZygoteRules.@adjoint function (f::FastDense)(x,p)
78
78
79
79
y = f. σ .(r)
80
80
81
- if typeof (f. σ) <: typeof (tanh)
81
+ if typeof (f. σ) <: typeof (tanh) || typeof (f . σ) <: typeof (identity)
82
82
ifgpufree (r)
83
83
end
84
84
@@ -87,25 +87,23 @@ ZygoteRules.@adjoint function (f::FastDense)(x,p)
87
87
zbar = ȳ .* (1 .- y.^ 2 )
88
88
elseif typeof (f. σ) <: typeof (identity)
89
89
zbar = ȳ
90
- ifgpufree (r)
91
90
else
92
91
zbar = ȳ .* ForwardDiff. derivative .(f. σ,r)
93
- ifgpufree (r)
94
92
end
95
- ifgpufree (y)
96
93
Wbar = zbar * x'
97
94
bbar = zbar
98
95
xbar = W' * zbar
99
96
pbar = typeof (bbar) <: AbstractVector ?
100
97
vec (vcat (vec (Wbar),bbar)) :
101
98
vec (vcat (vec (Wbar),sum (bbar,dims= 2 )))
102
- ifgpufree (Wbar); ifgpufree (bbar); ifgpufree (ȳ)
99
+ ifgpufree (Wbar); ifgpufree (bbar)
103
100
nothing ,xbar,pbar
104
101
end
105
102
y,FastDense_adjoint
106
103
end
107
104
paramlength (f:: FastDense ) = f. out* (f. in+ f. bias)
108
105
initial_params (f:: FastDense ) = f. initial_params ()
106
+
109
107
"""
110
108
StaticDense(in,out,activation=identity;
111
109
initW = Flux.glorot_uniform, initb = Flux.zeros)
You can’t perform that action at this time.
0 commit comments