Skip to content

Commit 02a4bb2

Browse files
authored
Replace LRP gradient computation with VJP using Zygote.pullback (#72)
1 parent 0d21c1c commit 02a4bb2

File tree

3 files changed

+5
-11
lines changed

3 files changed

+5
-11
lines changed

src/lrp_rules.jl

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,8 @@ function lrp_autodiff!(
1111
Rₖ::T1, rule::R, layer::L, aₖ::T1, Rₖ₊₁::T2
1212
) where {R<:AbstractLRPRule,L,T1,T2}
1313
layerᵨ = modify_layer(rule, layer)
14-
c::T1 = only(
15-
gradient(aₖ) do a
16-
z::T2 = layerᵨ(a)
17-
s = Zygote.@ignore Rₖ₊₁ ./ modify_denominator(rule, z)
18-
z s
19-
end,
20-
)
21-
Rₖ .= aₖ .* c
14+
ãₖ₊₁, back = Zygote.pullback(layerᵨ, aₖ)
15+
Rₖ .= aₖ .* only(back(Rₖ₊₁ ./ modify_denominator(rule, ãₖ₊₁)))
2216
return nothing
2317
end
2418

test/test_heatmaps.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ for r in reducers
2222
end
2323

2424
@test_throws ArgumentError heatmap(A, reduce=:foo)
25-
@test_throws ErrorException heatmap(A, rangescale=:bar)
25+
@test_throws ArgumentError heatmap(A, rangescale=:bar)
2626

2727
B = reshape(A, 2, 2, 3, 1, 1)
2828
@test_throws DomainError heatmap(B)

test/test_rules.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ layers = Dict(
141141
"MeanPool" => MaxPool((3, 3)),
142142
"ConvTranspose" => ConvTranspose((3, 3), 2 => 4; init=pseudorandn),
143143
"CrossCor" => CrossCor((3, 3), 2 => 4; init=pseudorandn),
144-
"flatten" => flatten,
144+
"flatten" => Flux.flatten,
145145
"Dropout" => Dropout(0.2),
146146
"AlphaDropout" => AlphaDropout(0.2),
147147
)
@@ -178,7 +178,7 @@ lrp!(Rₖ, rule::ZBoxRule, w::TestWrapper, aₖ, Rₖ₊₁) = lrp!(Rₖ, rule,
178178
layers = Dict(
179179
"Conv" => (Conv((3, 3), 2 => 4; init=pseudorandn), aₖ),
180180
"Dense_relu" => (Dense(ins_dense, outs_dense, relu; init=pseudorandn), aₖ_dense),
181-
"flatten" => (flatten, aₖ),
181+
"flatten" => (Flux.flatten, aₖ),
182182
)
183183
@testset "Custom layers" begin
184184
for (rulename, rule) in RULES

0 commit comments

Comments
 (0)