@@ -14,8 +14,8 @@ function lrp!(Rₖ, rule::R, layer::L, aₖ, Rₖ₊₁) where {R<:AbstractLRPRu
14
14
reset! = get_layer_resetter (rule, layer)
15
15
modify_layer! (rule, layer)
16
16
ãₖ = modify_input (rule, aₖ)
17
- ãₖ ₊₁, pullback = Zygote. pullback (preactivation (layer), ãₖ)
18
- Rₖ .= ãₖ .* only (pullback (Rₖ₊₁ ./ modify_denominator (rule, ãₖ ₊₁)))
17
+ zₖ ₊₁, pullback = Zygote. pullback (preactivation (layer), ãₖ)
18
+ Rₖ .= ãₖ .* only (pullback (Rₖ₊₁ ./ modify_denominator (rule, zₖ ₊₁)))
19
19
reset! ()
20
20
return nothing
21
21
end
@@ -235,24 +235,24 @@ function lrp!(Rₖ, rule::ZBoxRule, layer::L, aₖ, Rₖ₊₁) where {L}
235
235
h = zbox_input (aₖ, rule. high)
236
236
237
237
# Compute pullback for W, b
238
- aₖ ₊₁, pullback = Zygote. pullback (preactivation (layer), aₖ)
238
+ zₖ ₊₁, pullback = Zygote. pullback (preactivation (layer), aₖ)
239
239
240
240
# Compute pullback for W⁺, b⁺
241
241
modify_layer! (Val (:keep_positive ), layer)
242
- aₖ ₊₁⁺, pullback⁺ = Zygote. pullback (preactivation (layer), l)
242
+ zₖ ₊₁⁺, pullback⁺ = Zygote. pullback (preactivation (layer), l)
243
243
reset! ()
244
244
245
245
# Compute pullback for W⁻, b⁻
246
246
modify_layer! (Val (:keep_negative ), layer)
247
- aₖ ₊₁⁻, pullback⁻ = Zygote. pullback (preactivation (layer), h)
247
+ zₖ ₊₁⁻, pullback⁻ = Zygote. pullback (preactivation (layer), h)
248
248
249
249
# Evaluate pullbacks
250
- y = Rₖ₊₁ ./ modify_denominator (rule, aₖ ₊₁ - aₖ ₊₁⁺ - aₖ ₊₁⁻)
251
- Rₖ .= - h .* only (pullback⁻ (y ))
250
+ sₖ₊₁ = Rₖ₊₁ ./ modify_denominator (rule, zₖ ₊₁ - zₖ ₊₁⁺ - zₖ ₊₁⁻)
251
+ Rₖ .= - h .* only (pullback⁻ (sₖ₊₁ ))
252
252
reset! () # re-modify mutated pullback
253
- Rₖ .+ = aₖ .* only (pullback (y ))
253
+ Rₖ .+ = aₖ .* only (pullback (sₖ₊₁ ))
254
254
modify_layer! (Val (:keep_positive ), layer) # re-modify mutated pullback
255
- Rₖ .- = l .* only (pullback⁺ (y ))
255
+ Rₖ .- = l .* only (pullback⁺ (sₖ₊₁ ))
256
256
reset! ()
257
257
return nothing
258
258
end
@@ -300,30 +300,30 @@ function lrp!(Rₖ, rule::AlphaBetaRule, layer::L, aₖ, Rₖ₊₁) where {L}
300
300
301
301
# α: positive contributions
302
302
modify_layer! (Val (:keep_negative_zero_bias ), layer)
303
- aₖ ₊₁ᵅ⁻, pullbackᵅ⁻ = Zygote. pullback (preactivation (layer), aₖ⁻)
303
+ zₖ ₊₁ᵅ⁻, pullbackᵅ⁻ = Zygote. pullback (preactivation (layer), aₖ⁻)
304
304
reset! ()
305
305
modify_layer! (Val (:keep_positive ), layer)
306
- aₖ ₊₁ᵅ⁺, pullbackᵅ⁺ = Zygote. pullback (preactivation (layer), aₖ⁺)
306
+ zₖ ₊₁ᵅ⁺, pullbackᵅ⁺ = Zygote. pullback (preactivation (layer), aₖ⁺)
307
307
# evaluate pullbacks
308
- yᵅ = Rₖ₊₁ ./ modify_denominator (rule, aₖ ₊₁ᵅ⁺ + aₖ ₊₁ᵅ⁻)
309
- Rₖ .= rule. α .* aₖ⁺ .* only (pullbackᵅ⁺ (yᵅ ))
308
+ sₖ₊₁ᵅ = Rₖ₊₁ ./ modify_denominator (rule, zₖ ₊₁ᵅ⁺ + zₖ ₊₁ᵅ⁻)
309
+ Rₖ .= rule. α .* aₖ⁺ .* only (pullbackᵅ⁺ (sₖ₊₁ᵅ ))
310
310
reset! ()
311
311
modify_layer! (Val (:keep_negative_zero_bias ), layer) # re-modify mutated pullback
312
- Rₖ .+ = rule. α .* aₖ⁻ .* only (pullbackᵅ⁻ (yᵅ ))
312
+ Rₖ .+ = rule. α .* aₖ⁻ .* only (pullbackᵅ⁻ (sₖ₊₁ᵅ ))
313
313
reset! ()
314
314
315
315
# β: Negative contributions
316
316
modify_layer! (Val (:keep_positive_zero_bias ), layer)
317
- aₖ ₊₁ᵝ⁻, pullbackᵝ⁻ = Zygote. pullback (preactivation (layer), aₖ⁻) #
317
+ zₖ ₊₁ᵝ⁻, pullbackᵝ⁻ = Zygote. pullback (preactivation (layer), aₖ⁻) #
318
318
reset! ()
319
319
modify_layer! (Val (:keep_negative ), layer)
320
- aₖ ₊₁ᵝ⁺, pullbackᵝ⁺ = Zygote. pullback (preactivation (layer), aₖ⁺)
320
+ zₖ ₊₁ᵝ⁺, pullbackᵝ⁺ = Zygote. pullback (preactivation (layer), aₖ⁺)
321
321
# evaluate pullbacks
322
- yᵝ = Rₖ₊₁ ./ modify_denominator (rule, aₖ ₊₁ᵝ⁺ + aₖ ₊₁ᵝ⁻)
323
- Rₖ .- = rule. β .* aₖ⁺ .* only (pullbackᵝ⁺ (yᵝ ))
322
+ sₖ₊₁ᵝ = Rₖ₊₁ ./ modify_denominator (rule, zₖ ₊₁ᵝ⁺ + zₖ ₊₁ᵝ⁻)
323
+ Rₖ .- = rule. β .* aₖ⁺ .* only (pullbackᵝ⁺ (sₖ₊₁ᵝ ))
324
324
reset! ()
325
325
modify_layer! (Val (:keep_positive_zero_bias ), layer) # re-modify mutated pullback
326
- Rₖ .- = rule. β .* aₖ⁻ .* only (pullbackᵝ⁻ (yᵝ ))
326
+ Rₖ .- = rule. β .* aₖ⁻ .* only (pullbackᵝ⁻ (sₖ₊₁ᵝ ))
327
327
reset! ()
328
328
return nothing
329
329
end
@@ -350,17 +350,17 @@ function lrp!(Rₖ, rule::ZPlusRule, layer::L, aₖ, Rₖ₊₁) where {L}
350
350
351
351
# Linearize around positive & negative activations (aₖ⁺, aₖ⁻)
352
352
modify_layer! (Val (:keep_positive ), layer)
353
- aₖ ₊₁⁺, pullback⁺ = Zygote. pullback (layer, aₖ⁺)
353
+ zₖ ₊₁⁺, pullback⁺ = Zygote. pullback (layer, aₖ⁺)
354
354
reset! ()
355
355
modify_layer! (Val (:keep_negative_zero_bias ), layer)
356
- aₖ ₊₁⁻, pullback⁻ = Zygote. pullback (layer, aₖ⁻)
356
+ zₖ ₊₁⁻, pullback⁻ = Zygote. pullback (layer, aₖ⁻)
357
357
358
358
# Evaluate pullbacks
359
- y = Rₖ₊₁ ./ modify_denominator (rule, aₖ ₊₁⁺ + aₖ ₊₁⁻)
360
- Rₖ .= aₖ⁻ .* only (pullback⁻ (y ))
359
+ sₖ₊₁ = Rₖ₊₁ ./ modify_denominator (rule, zₖ ₊₁⁺ + zₖ ₊₁⁻)
360
+ Rₖ .= aₖ⁻ .* only (pullback⁻ (sₖ₊₁ ))
361
361
reset! ()
362
362
modify_layer! (Val (:keep_positive ), layer) # re-modify mutated pullback
363
- Rₖ .+ = aₖ⁺ .* only (pullback⁺ (y ))
363
+ Rₖ .+ = aₖ⁺ .* only (pullback⁺ (sₖ₊₁ ))
364
364
reset! ()
365
365
return nothing
366
366
end
@@ -378,8 +378,8 @@ for R in (ZeroRule, EpsilonRule, GammaRule, WSquareRule, FlatRule)
378
378
reset! = get_layer_resetter (rule, layer)
379
379
modify_layer! (rule, layer)
380
380
ãₖ = modify_input (rule, aₖ)
381
- ãₖ ₊₁ = modify_denominator (rule, preactivation (layer, ãₖ))
382
- @tullio Rₖ[j, b] = ãₖ[j, b] * layer. weight[k, j] * Rₖ₊₁[k, b] / ãₖ ₊₁[k, b]
381
+ zₖ ₊₁ = modify_denominator (rule, preactivation (layer, ãₖ))
382
+ @tullio Rₖ[j, b] = ãₖ[j, b] * layer. weight[k, j] * Rₖ₊₁[k, b] / zₖ ₊₁[k, b]
383
383
reset! ()
384
384
return nothing
385
385
end
0 commit comments