Skip to content

Commit 8f428c8

Browse files
committed
Rename internal lrp! variable names to match literature
Add distinction between activations `a` and pre-activations `z`.
1 parent 8ca2a4b commit 8f428c8

File tree

1 file changed

+26
-26
lines changed

1 file changed

+26
-26
lines changed

src/lrp/rules.jl

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@ function lrp!(Rₖ, rule::R, layer::L, aₖ, Rₖ₊₁) where {R<:AbstractLRPRu
1414
reset! = get_layer_resetter(rule, layer)
1515
modify_layer!(rule, layer)
1616
ãₖ = modify_input(rule, aₖ)
17-
ãₖ₊₁, pullback = Zygote.pullback(preactivation(layer), ãₖ)
18-
Rₖ .= ãₖ .* only(pullback(Rₖ₊₁ ./ modify_denominator(rule, ãₖ₊₁)))
17+
zₖ₊₁, pullback = Zygote.pullback(preactivation(layer), ãₖ)
18+
Rₖ .= ãₖ .* only(pullback(Rₖ₊₁ ./ modify_denominator(rule, zₖ₊₁)))
1919
reset!()
2020
return nothing
2121
end
@@ -235,24 +235,24 @@ function lrp!(Rₖ, rule::ZBoxRule, layer::L, aₖ, Rₖ₊₁) where {L}
235235
h = zbox_input(aₖ, rule.high)
236236

237237
# Compute pullback for W, b
238-
aₖ₊₁, pullback = Zygote.pullback(preactivation(layer), aₖ)
238+
zₖ₊₁, pullback = Zygote.pullback(preactivation(layer), aₖ)
239239

240240
# Compute pullback for W⁺, b⁺
241241
modify_layer!(Val(:keep_positive), layer)
242-
aₖ₊₁⁺, pullback⁺ = Zygote.pullback(preactivation(layer), l)
242+
zₖ₊₁⁺, pullback⁺ = Zygote.pullback(preactivation(layer), l)
243243
reset!()
244244

245245
# Compute pullback for W⁻, b⁻
246246
modify_layer!(Val(:keep_negative), layer)
247-
aₖ₊₁⁻, pullback⁻ = Zygote.pullback(preactivation(layer), h)
247+
zₖ₊₁⁻, pullback⁻ = Zygote.pullback(preactivation(layer), h)
248248

249249
# Evaluate pullbacks
250-
y = Rₖ₊₁ ./ modify_denominator(rule, aₖ₊₁ - aₖ₊₁⁺ - aₖ₊₁⁻)
251-
Rₖ .= -h .* only(pullback⁻(y))
250+
sₖ₊₁ = Rₖ₊₁ ./ modify_denominator(rule, zₖ₊₁ - zₖ₊₁⁺ - zₖ₊₁⁻)
251+
Rₖ .= -h .* only(pullback⁻(sₖ₊₁))
252252
reset!() # re-modify mutated pullback
253-
Rₖ .+= aₖ .* only(pullback(y))
253+
Rₖ .+= aₖ .* only(pullback(sₖ₊₁))
254254
modify_layer!(Val(:keep_positive), layer) # re-modify mutated pullback
255-
Rₖ .-= l .* only(pullback⁺(y))
255+
Rₖ .-= l .* only(pullback⁺(sₖ₊₁))
256256
reset!()
257257
return nothing
258258
end
@@ -300,30 +300,30 @@ function lrp!(Rₖ, rule::AlphaBetaRule, layer::L, aₖ, Rₖ₊₁) where {L}
300300

301301
# α: positive contributions
302302
modify_layer!(Val(:keep_negative_zero_bias), layer)
303-
aₖ₊₁ᵅ⁻, pullbackᵅ⁻ = Zygote.pullback(preactivation(layer), aₖ⁻)
303+
zₖ₊₁ᵅ⁻, pullbackᵅ⁻ = Zygote.pullback(preactivation(layer), aₖ⁻)
304304
reset!()
305305
modify_layer!(Val(:keep_positive), layer)
306-
aₖ₊₁ᵅ⁺, pullbackᵅ⁺ = Zygote.pullback(preactivation(layer), aₖ⁺)
306+
zₖ₊₁ᵅ⁺, pullbackᵅ⁺ = Zygote.pullback(preactivation(layer), aₖ⁺)
307307
# evaluate pullbacks
308-
yᵅ = Rₖ₊₁ ./ modify_denominator(rule, aₖ₊₁ᵅ⁺ + aₖ₊₁ᵅ⁻)
309-
Rₖ .= rule.α .* aₖ⁺ .* only(pullbackᵅ⁺(yᵅ))
308+
sₖ₊₁ᵅ = Rₖ₊₁ ./ modify_denominator(rule, zₖ₊₁ᵅ⁺ + zₖ₊₁ᵅ⁻)
309+
Rₖ .= rule.α .* aₖ⁺ .* only(pullbackᵅ⁺(sₖ₊₁ᵅ))
310310
reset!()
311311
modify_layer!(Val(:keep_negative_zero_bias), layer) # re-modify mutated pullback
312-
Rₖ .+= rule.α .* aₖ⁻ .* only(pullbackᵅ⁻(yᵅ))
312+
Rₖ .+= rule.α .* aₖ⁻ .* only(pullbackᵅ⁻(sₖ₊₁ᵅ))
313313
reset!()
314314

315315
# β: Negative contributions
316316
modify_layer!(Val(:keep_positive_zero_bias), layer)
317-
aₖ₊₁ᵝ⁻, pullbackᵝ⁻ = Zygote.pullback(preactivation(layer), aₖ⁻) #
317+
zₖ₊₁ᵝ⁻, pullbackᵝ⁻ = Zygote.pullback(preactivation(layer), aₖ⁻) #
318318
reset!()
319319
modify_layer!(Val(:keep_negative), layer)
320-
aₖ₊₁ᵝ⁺, pullbackᵝ⁺ = Zygote.pullback(preactivation(layer), aₖ⁺)
320+
zₖ₊₁ᵝ⁺, pullbackᵝ⁺ = Zygote.pullback(preactivation(layer), aₖ⁺)
321321
# evaluate pullbacks
322-
yᵝ = Rₖ₊₁ ./ modify_denominator(rule, aₖ₊₁ᵝ⁺ + aₖ₊₁ᵝ⁻)
323-
Rₖ .-= rule.β .* aₖ⁺ .* only(pullbackᵝ⁺(yᵝ))
322+
sₖ₊₁ᵝ = Rₖ₊₁ ./ modify_denominator(rule, zₖ₊₁ᵝ⁺ + zₖ₊₁ᵝ⁻)
323+
Rₖ .-= rule.β .* aₖ⁺ .* only(pullbackᵝ⁺(sₖ₊₁ᵝ))
324324
reset!()
325325
modify_layer!(Val(:keep_positive_zero_bias), layer) # re-modify mutated pullback
326-
Rₖ .-= rule.β .* aₖ⁻ .* only(pullbackᵝ⁻(yᵝ))
326+
Rₖ .-= rule.β .* aₖ⁻ .* only(pullbackᵝ⁻(sₖ₊₁ᵝ))
327327
reset!()
328328
return nothing
329329
end
@@ -350,17 +350,17 @@ function lrp!(Rₖ, rule::ZPlusRule, layer::L, aₖ, Rₖ₊₁) where {L}
350350

351351
# Linearize around positive & negative activations (aₖ⁺, aₖ⁻)
352352
modify_layer!(Val(:keep_positive), layer)
353-
aₖ₊₁⁺, pullback⁺ = Zygote.pullback(layer, aₖ⁺)
353+
zₖ₊₁⁺, pullback⁺ = Zygote.pullback(layer, aₖ⁺)
354354
reset!()
355355
modify_layer!(Val(:keep_negative_zero_bias), layer)
356-
aₖ₊₁⁻, pullback⁻ = Zygote.pullback(layer, aₖ⁻)
356+
zₖ₊₁⁻, pullback⁻ = Zygote.pullback(layer, aₖ⁻)
357357

358358
# Evaluate pullbacks
359-
y = Rₖ₊₁ ./ modify_denominator(rule, aₖ₊₁⁺ + aₖ₊₁⁻)
360-
Rₖ .= aₖ⁻ .* only(pullback⁻(y))
359+
sₖ₊₁ = Rₖ₊₁ ./ modify_denominator(rule, zₖ₊₁⁺ + zₖ₊₁⁻)
360+
Rₖ .= aₖ⁻ .* only(pullback⁻(sₖ₊₁))
361361
reset!()
362362
modify_layer!(Val(:keep_positive), layer) # re-modify mutated pullback
363-
Rₖ .+= aₖ⁺ .* only(pullback⁺(y))
363+
Rₖ .+= aₖ⁺ .* only(pullback⁺(sₖ₊₁))
364364
reset!()
365365
return nothing
366366
end
@@ -378,8 +378,8 @@ for R in (ZeroRule, EpsilonRule, GammaRule, WSquareRule, FlatRule)
378378
reset! = get_layer_resetter(rule, layer)
379379
modify_layer!(rule, layer)
380380
ãₖ = modify_input(rule, aₖ)
381-
ãₖ₊₁ = modify_denominator(rule, preactivation(layer, ãₖ))
382-
@tullio Rₖ[j, b] = ãₖ[j, b] * layer.weight[k, j] * Rₖ₊₁[k, b] / ãₖ₊₁[k, b]
381+
zₖ₊₁ = modify_denominator(rule, preactivation(layer, ãₖ))
382+
@tullio Rₖ[j, b] = ãₖ[j, b] * layer.weight[k, j] * Rₖ₊₁[k, b] / zₖ₊₁[k, b]
383383
reset!()
384384
return nothing
385385
end

0 commit comments

Comments
 (0)