@@ -48,40 +48,40 @@ LRP(model::Chain, c::Composite; kwargs...) = LRP(model, lrp_rules(model, c); kwa
48
48
49
49
get_activations (model, input) = [input, Flux. activations (model, input)... ]
50
50
51
+ function mask_output_neuron! (Rᴺ, aᴺ, ns:: AbstractNeuronSelector )
52
+ fill! (Rᴺ, 0 )
53
+ idx = ns (aᴺ)
54
+ Rᴺ[idx] .= 1
55
+ return Rᴺ
56
+ end
57
+
51
58
# Call to the LRP analyzer
52
59
function (lrp:: LRP )(
53
60
input:: AbstractArray{T} , ns:: AbstractNeuronSelector ; layerwise_relevances= false
54
61
) where {T}
55
- acts = get_activations (lrp. model, input) # compute aᵏ for all layers k
56
- rels = similar .(acts) # allocate Rᵏ for all layers k
57
- mask_output_neuron! (rels[end ], acts[end ], ns) # compute Rᵏ⁺¹ of output layer
58
-
59
- # Apply LRP rules in backward-pass, inplace-updating relevances `rels[i]`
60
- for i in length (lrp. model): - 1 : 1
61
- lrp! (
62
- rels[i],
63
- lrp. rules[i],
64
- lrp. model[i],
65
- lrp. modified_layers[i],
66
- acts[i],
67
- rels[i + 1 ],
68
- )
62
+ as = get_activations (lrp. model, input) # compute activations aᵏ for all layers k
63
+ Rs = similar .(as) # allocate relevances Rᵏ for all layers k
64
+ mask_output_neuron! (Rs[end ], as[end ], ns) # compute relevance Rᴺ of output layer N
65
+
66
+ # Apply LRP rules in backward-pass, inplace-updating relevances `Rs[k]` = Rᵏ
67
+ for k in length (lrp. model): - 1 : 1
68
+ lrp! (Rs[k], lrp. rules[k], lrp. model[k], lrp. modified_layers[k], as[k], Rs[k + 1 ])
69
69
end
70
- extras = layerwise_relevances ? (layerwise_relevances= rels,) : nothing
71
70
72
- return Explanation (first (rels), last (acts), ns (last (acts)), :LRP , extras)
71
+ extras = layerwise_relevances ? (layerwise_relevances= Rs,) : nothing
72
+ return Explanation (first (Rs), last (as), ns (last (as)), :LRP , extras)
73
73
end
74
74
75
75
function lrp! (Rᵏ, rules:: ChainTuple , chain:: Chain , modified_chain:: ChainTuple , aᵏ, Rᵏ⁺¹)
76
- acts = get_activations (chain, aᵏ)
77
- rels = similar .(acts )
78
- last (rels ) .= Rᵏ⁺¹
76
+ as = get_activations (chain, aᵏ)
77
+ Rs = similar .(as )
78
+ last (Rs ) .= Rᵏ⁺¹
79
79
80
- # Apply LRP rules in backward-pass, inplace-updating relevances `rels [i]`
80
+ # Apply LRP rules in backward-pass, inplace-updating relevances `Rs [i]`
81
81
for i in length (chain): - 1 : 1
82
- lrp! (rels [i], rules[i], chain[i], modified_chain[i], acts [i], rels [i + 1 ])
82
+ lrp! (Rs [i], rules[i], chain[i], modified_chain[i], as [i], Rs [i + 1 ])
83
83
end
84
- return Rᵏ .= first (rels )
84
+ return Rᵏ .= first (Rs )
85
85
end
86
86
87
87
function lrp! (
0 commit comments