
Commit f5970b6

Fast FlatRule (#96)
1 parent bf4c581 commit f5970b6

File tree

1 file changed: +10 −3 lines changed


src/lrp/rules.jl

Lines changed: 10 additions & 3 deletions
@@ -213,9 +213,9 @@ and all bias terms set to zero.
 # Definition
 Propagates relevance ``R^{k+1}`` at layer output to ``R^k`` at layer input according to
 ```math
-R_j^k = \\sum_i\\frac{1}{\\sum_l 1} R_i^{k+1} = \\frac{1}{n}\\sum_i R_i^{k+1}
+R_j^k = \\sum_i\\frac{1}{\\sum_l 1} R_i^{k+1} = \\sum_i\\frac{1}{n_i} R_i^{k+1}
 ```
-where ``n`` is the number of input neurons connected to the output neuron at index ``i``.
+where ``n_i`` is the number of input neurons connected to the output neuron at index ``i``.

 # References
 - $REF_LAPUSCHKIN_CLEVER_HANS
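For a Dense layer, every input neuron is connected to every output neuron, so ``n_i`` equals the number of input neurons ``n`` and the rule reduces to spreading the total output relevance evenly over the inputs (which is what the new specialized method further below does). A minimal sketch of this redistribution in Julia, with made-up relevance values and illustrative variable names:

    Rₖ₊₁ = [0.2, 0.5, 0.3]        # relevance at the layer output (3 output neurons)
    n = 4                          # number of input neurons, all connected to every output
    Rₖ = fill(sum(Rₖ₊₁) / n, n)    # every input neuron receives the same share
    sum(Rₖ) ≈ sum(Rₖ₊₁)            # true: relevance is conserved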
@@ -434,7 +434,7 @@ for R in (ZeroRule, EpsilonRule)
 end

 # Fast implementation for Dense layer using Tullio.jl's einsum notation:
-for R in (ZeroRule, EpsilonRule, GammaRule, WSquareRule, FlatRule)
+for R in (ZeroRule, EpsilonRule, GammaRule, WSquareRule)
     @eval function lrp!(Rₖ, rule::$R, layer::Dense, aₖ, Rₖ₊₁)
         reset! = get_layer_resetter(rule, layer)
         modify_layer!(rule, layer)
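The loop above generates one specialized `lrp!` method per rule via `@eval`, each built on Tullio's einsum notation. As a point of reference, the sketch below shows roughly what such an einsum-style Dense-layer step can look like for the basic LRP-0 formula; it is a hypothetical, self-contained illustration, not the package's actual implementation, and `lrp_dense_sketch` and its arguments are assumed names:

    using Tullio

    # Sketch of an einsum-style LRP-0 step for a Dense layer:
    # zₖ₊₁ = W aₖ + b,  Rₖ[j] = aₖ[j] * Σᵢ W[i, j] * Rₖ₊₁[i] / zₖ₊₁[i]
    function lrp_dense_sketch(W, b, aₖ, Rₖ₊₁)
        zₖ₊₁ = W * aₖ .+ b    # forward pass through the Dense layer
        s = Rₖ₊₁ ./ zₖ₊₁      # per-output relevance factor (a small stabilizer is often added here)
        @tullio Rₖ[j, n] := aₖ[j, n] * W[i, j] * s[i, n]  # redistribute, summing over outputs i
        return Rₖ
    end

    # e.g. lrp_dense_sketch(randn(5, 3), zeros(5), rand(3, 8), rand(5, 8))
    # with 5 output neurons, 3 input neurons, and a batch of 8 samples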
@@ -445,3 +445,10 @@ for R in (ZeroRule, EpsilonRule, GammaRule, WSquareRule, FlatRule)
         return nothing
     end
 end
+function lrp!(Rₖ, ::FlatRule, layer::Dense, aₖ, Rₖ₊₁)
+    n = size(Rₖ, 1) # number of input neurons connected to each output neuron
+    for i in axes(Rₖ, 2) # samples in batch
+        fill!(view(Rₖ, :, i), sum(view(Rₖ₊₁, :, i)) / n)
+    end
+    return nothing
+end
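Because the FlatRule replaces all weights with ones and all biases with zero, the propagated relevance for a Dense layer does not depend on the layer parameters or the input activations; each input neuron simply receives the total output relevance of its sample divided by the number of input neurons. The added method therefore skips the generic einsum path and fills each column directly. A rough, self-contained check of that fill-based loop against a naive per-sample computation (the array sizes and the `Rₖ_ref` name are made up for illustration):

    Rₖ₊₁ = rand(Float32, 5, 8)    # relevance at the output: (output neurons, batch)
    Rₖ   = zeros(Float32, 3, 8)   # relevance buffer at the input: (input neurons, batch)
    n = size(Rₖ, 1)
    for i in axes(Rₖ, 2)          # loop over samples in the batch
        fill!(view(Rₖ, :, i), sum(view(Rₖ₊₁, :, i)) / n)
    end
    # naive reference: each input neuron gets its sample's total output relevance divided by n
    Rₖ_ref = repeat(sum(Rₖ₊₁; dims=1) ./ n, n, 1)
    Rₖ ≈ Rₖ_ref                   # true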
