Skip to content

Commit 256a4fb

Browse files
committed
More fixups
1 parent 6e64553 commit 256a4fb

File tree

1 file changed

+3
-9
lines changed

1 file changed

+3
-9
lines changed

src/enzyme.jl

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NN
122122
end
123123

124124
ddsts = dst.dval
125-
dsrcs = src.dval
125+
dsrcs = (typeof(src) <: EnzymeCore.Const) ? ddsts : src.dval
126126

127127
if EnzymeCore.EnzymeRules.width(config) == 1
128128
ddsts = (ddsts,)
@@ -182,7 +182,7 @@ function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NN
182182
end
183183

184184
ddsts = dst.dval
185-
dsrcs = src.dval
185+
dsrcs = (typeof(src) <: EnzymeCore.Const) ? ddsts : src.dval
186186

187187
if EnzymeCore.EnzymeRules.width(config) == 1
188188
ddsts = (ddsts,)
@@ -322,12 +322,6 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{
322322
keep = nothing
323323
end
324324

325-
# Cache idx if its overwritten
326-
cache_idx = ( EnzymeCore.EnzymeRules.overwritten(config)[4]
327-
&& !(typeof(src) <: EnzymeCore.Const)
328-
&& !(typeof(dst) <: EnzymeCore.Const)
329-
) ? copy(idx.val) : nothing
330-
331325
return EnzymeCore.EnzymeRules.AugmentedReturn(primal, shadow, keep)
332326
end
333327

@@ -336,7 +330,7 @@ function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NN
336330
val = convert(T, 1/(1-p.val))
337331

338332
ddsts = dst.dval
339-
dsrcs = src.dval
333+
dsrcs = (typeof(src) <: EnzymeCore.Const) ? ddsts : src.dval
340334

341335
if EnzymeCore.EnzymeRules.width(config) == 1
342336
ddsts = (ddsts,)

0 commit comments

Comments
 (0)