@@ -12,7 +12,7 @@ const LRP_DEFAULT_BETA = 1.0f0
12
12
function lrp! (Rᵏ, rule:: AbstractLRPRule , layer, modified_layer, aᵏ, Rᵏ⁺¹)
13
13
layer = isnothing (modified_layer) ? layer : modified_layer
14
14
ãᵏ = modify_input (rule, aᵏ)
15
- z, back = Zygote . pullback (layer, ãᵏ)
15
+ z, back = pullback (layer, ãᵏ)
16
16
s = Rᵏ⁺¹ ./ modify_denominator (rule, z)
17
17
c = only (back (s))
18
18
Rᵏ .= ãᵏ .* c
@@ -338,9 +338,9 @@ function lrp!(Rᵏ, rule::ZBoxRule, layer, modified_layers, aᵏ, Rᵏ⁺¹)
338
338
l = zbox_input (aᵏ, rule. low)
339
339
h = zbox_input (aᵏ, rule. high)
340
340
341
- z, back = Zygote . pullback (layer, aᵏ)
342
- z⁺, back⁺ = Zygote . pullback (modified_layers. layer⁺, l)
343
- z⁻, back⁻ = Zygote . pullback (modified_layers. layer⁻, h)
341
+ z, back = pullback (layer, aᵏ)
342
+ z⁺, back⁺ = pullback (modified_layers. layer⁺, l)
343
+ z⁻, back⁻ = pullback (modified_layers. layer⁻, h)
344
344
345
345
s = Rᵏ⁺¹ ./ modify_denominator (rule, z - z⁺ - z⁻)
346
346
c = only (back (s))
@@ -402,8 +402,8 @@ function lrp!(Rᵏ, rule::AlphaBetaRule, _layer, modified_layers, aᵏ, Rᵏ⁺
402
402
aᵏ⁺ = keep_positive (aᵏ)
403
403
aᵏ⁻ = keep_negative (aᵏ)
404
404
405
- zᵅ⁺, back⁺ = Zygote . pullback (modified_layers. layerᵅ⁺, aᵏ⁺)
406
- zᵅ⁻, back⁻ = Zygote . pullback (modified_layers. layerᵅ⁻, aᵏ⁻)
405
+ zᵅ⁺, back⁺ = pullback (modified_layers. layerᵅ⁺, aᵏ⁺)
406
+ zᵅ⁻, back⁻ = pullback (modified_layers. layerᵅ⁻, aᵏ⁻)
407
407
# No need to linearize again: Wᵝ⁺ = Wᵅ⁺ and Wᵝ⁻ = Wᵅ⁻
408
408
zᵝ⁺ = modified_layers. layerᵝ⁺ (aᵏ⁻)
409
409
zᵝ⁻ = modified_layers. layerᵝ⁻ (aᵏ⁺)
@@ -451,8 +451,8 @@ function lrp!(Rᵏ, rule::ZPlusRule, _layer, modified_layers, aᵏ, Rᵏ⁺¹)
451
451
aᵏ⁺ = keep_positive (aᵏ)
452
452
aᵏ⁻ = keep_negative (aᵏ)
453
453
454
- z⁺, back⁺ = Zygote . pullback (modified_layers. layer⁺, aᵏ⁺)
455
- z⁻, back⁻ = Zygote . pullback (modified_layers. layer⁻, aᵏ⁻)
454
+ z⁺, back⁺ = pullback (modified_layers. layer⁺, aᵏ⁺)
455
+ z⁻, back⁻ = pullback (modified_layers. layer⁻, aᵏ⁻)
456
456
457
457
s = Rᵏ⁺¹ ./ modify_denominator (rule, z⁺ + z⁻)
458
458
c⁺ = only (back⁺ (s))
@@ -504,8 +504,8 @@ function lrp!(Rᵏ, rule::GeneralizedGammaRule, layer, modified_layers, aᵏ, R
504
504
aᵏ⁺ = keep_positive (aᵏ)
505
505
aᵏ⁻ = keep_negative (aᵏ)
506
506
507
- zˡ⁺, back⁺ = Zygote . pullback (modified_layers. layerˡ⁺, aᵏ⁺)
508
- zˡ⁻, back⁻ = Zygote . pullback (modified_layers. layerˡ⁻, aᵏ⁻)
507
+ zˡ⁺, back⁺ = pullback (modified_layers. layerˡ⁺, aᵏ⁺)
508
+ zˡ⁻, back⁻ = pullback (modified_layers. layerˡ⁻, aᵏ⁻)
509
509
# No need to linearize again: Wˡ⁺ = Wʳ⁺ and Wˡ⁻ = Wʳ⁻
510
510
zʳ⁺ = modified_layers. layerʳ⁺ (aᵏ⁻)
511
511
zʳ⁻ = modified_layers. layerʳ⁻ (aᵏ⁺)
0 commit comments