Skip to content

Commit 8ca2a4b

Browse files
authored
Fix input modification for FlatRule and WSquareRule (#93)
1 parent 239deae commit 8ca2a4b

File tree

8 files changed

+3
-2
lines changed

8 files changed

+3
-2
lines changed

src/lrp/rules.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,9 @@ function lrp!(Rₖ, rule::R, layer::L, aₖ, Rₖ₊₁) where {R<:AbstractLRPRu
1313
check_compat(rule, layer)
1414
reset! = get_layer_resetter(rule, layer)
1515
modify_layer!(rule, layer)
16-
ãₖ₊₁, pullback = Zygote.pullback(preactivation(layer), modify_input(rule, aₖ))
17-
Rₖ .= aₖ .* only(pullback(Rₖ₊₁ ./ modify_denominator(rule, ãₖ₊₁)))
16+
ãₖ = modify_input(rule, aₖ)
17+
ãₖ₊₁, pullback = Zygote.pullback(preactivation(layer), ãₖ)
18+
Rₖ .= ãₖ .* only(pullback(Rₖ₊₁ ./ modify_denominator(rule, ãₖ₊₁)))
1819
reset!()
1920
return nothing
2021
end
0 Bytes
Binary file not shown.
0 Bytes
Binary file not shown.
0 Bytes
Binary file not shown.
0 Bytes
Binary file not shown.
0 Bytes
Binary file not shown.
0 Bytes
Binary file not shown.
0 Bytes
Binary file not shown.

0 commit comments

Comments
 (0)