@@ -46,13 +46,21 @@ LRP(model::Chain; kwargs...) = LRP(model, Composite(ZeroRule()); kwargs...)
 # Construct Chain-/ParallelTuple of rules by applying composite
 LRP(model::Chain, c::Composite; kwargs...) = LRP(model, lrp_rules(model, c); kwargs...)
 
-get_activations(model, input) = [input, Flux.activations(model, input)...]
+get_activations(model, input) = (input, Flux.activations(model, input)...)
 
-function mask_output_neuron!(Rᴺ, aᴺ, ns::AbstractNeuronSelector)
-    fill!(Rᴺ, 0)
-    idx = ns(aᴺ)
-    Rᴺ[idx] .= 1
-    return Rᴺ
+function mask_output_neuron!(R_out, a_out, ns::AbstractNeuronSelector)
+    fill!(R_out, 0)
+    idx = ns(a_out)
+    R_out[idx] .= 1
+    return R_out
+end
+
+function lrp_backward_pass!(Rs, as, rules, layers, modified_layers)
+    # Apply LRP rules in backward pass, in-place updating relevances `Rs[k]` = Rᵏ
+    for k in length(layers):-1:1
+        lrp!(Rs[k], rules[k], layers[k], modified_layers[k], as[k], Rs[k + 1])
+    end
+    return Rs
 end
 
 # Call to the LRP analyzer
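Note on the `get_activations` change: returning a tuple instead of a `Vector` keeps the concrete type of each differently-sized activation, so the downstream `similar.(as)` allocation stays type-stable. A minimal sketch of what the helper produces, with a made-up Flux model for illustration:

```julia
using Flux

model = Chain(Dense(4 => 3, relu), Dense(3 => 2))
input = rand(Float32, 4)

# Mirrors get_activations: a tuple (input, a¹, …, aᴺ) with one entry
# per layer boundary of the chain.
as = (input, Flux.activations(model, input)...)
@assert length(as) == length(model) + 1
```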
@@ -63,11 +71,7 @@ function (lrp::LRP)(
     Rs = similar.(as)                          # allocate relevances Rᵏ for all layers k
     mask_output_neuron!(Rs[end], as[end], ns)  # compute relevance Rᴺ of output layer N
 
-    # Apply LRP rules in backward-pass, inplace-updating relevances `Rs[k]` = Rᵏ
-    for k in length(lrp.model):-1:1
-        lrp!(Rs[k], lrp.rules[k], lrp.model[k], lrp.modified_layers[k], as[k], Rs[k + 1])
-    end
-
+    lrp_backward_pass!(Rs, as, lrp.rules, lrp.model, lrp.modified_layers)
     extras = layerwise_relevances ? (layerwise_relevances=Rs,) : nothing
     return Explanation(first(Rs), last(as), ns(last(as)), :LRP, extras)
 end
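For context, a usage sketch of the analyzer call refactored above. The `LRP` constructor and the `layerwise_relevances` keyword come from this diff; the `MaxActivationSelector` neuron selector and the `Explanation` field names are assumptions about the package API, not confirmed here:

```julia
using Flux

model = Chain(Dense(4 => 3, relu), Dense(3 => 2))
analyzer = LRP(model)       # defaults to Composite(ZeroRule()), per the first hunk
input = rand(Float32, 4, 1)

# Assumption: a MaxActivationSelector <: AbstractNeuronSelector exists and
# the analyzer is callable like this.
expl = analyzer(input, MaxActivationSelector(); layerwise_relevances=true)

expl.val                          # first(Rs): relevance w.r.t. the input (assumed field name)
expl.extras.layerwise_relevances  # all Rᵏ, since layerwise_relevances=true
```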
@@ -77,10 +81,7 @@ function lrp!(Rᵏ, rules::ChainTuple, chain::Chain, modified_chain::ChainTuple,
     Rs = similar.(as)
     last(Rs) .= Rᵏ⁺¹
 
-    # Apply LRP rules in backward-pass, inplace-updating relevances `Rs[i]`
-    for i in length(chain):-1:1
-        lrp!(Rs[i], rules[i], chain[i], modified_chain[i], as[i], Rs[i + 1])
-    end
+    lrp_backward_pass!(Rs, as, rules, chain, modified_chain)
     return Rᵏ .= first(Rs)
 end
 
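The `ChainTuple` method above seeds `last(Rs)` with the incoming relevance Rᵏ⁺¹ and writes the result of the nested backward pass into Rᵏ, which lets nested `Chain`s recurse through the same `lrp!` entry point. A self-contained toy of that seed-and-propagate pattern, with a hypothetical uniform rule standing in for `lrp!`:

```julia
# Hypothetical rule: spread the total incoming relevance uniformly over the
# inputs. Real LRP rules weight by contributions; this only shows the flow.
toy_rule!(Rᵏ, Rᵏ⁺¹) = Rᵏ .= sum(Rᵏ⁺¹) / length(Rᵏ)

as = (rand(4), rand(3), rand(2))  # activations of a 2-layer toy chain
Rs = similar.(as)
last(Rs) .= [0.0, 1.0]            # seed: relevance enters at the output

for k in length(Rs)-1:-1:1        # backward pass, as in lrp_backward_pass!
    toy_rule!(Rs[k], Rs[k + 1])
end

@assert sum(first(Rs)) ≈ sum(last(Rs))  # relevance is conserved layer by layer
```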
@@ -91,15 +92,15 @@ function lrp!(
     # according to the contribution aᵏ⁺¹ᵢ of branch i to the output activation aᵏ⁺¹:
     # Rᵏ⁺¹ᵢ = Rᵏ⁺¹ .* aᵏ⁺¹ᵢ ./ aᵏ⁺¹ = c .* aᵏ⁺¹ᵢ
 
-    aᵏ⁺¹s = [l(aᵏ) for l in parallel.layers]      # aᵏ⁺¹ᵢ for each branch i
-    c = Rᵏ⁺¹ ./ stabilize_denom(sum(aᵏ⁺¹s))
-    Rᵏ⁺¹s = [c .* aᵏ⁺¹ᵢ for aᵏ⁺¹ᵢ in aᵏ⁺¹s]       # Rᵏ⁺¹ᵢ for each branch i
-    Rᵏs = [similar(aᵏ) for _ in parallel.layers]  # pre-allocate output Rᵏᵢ for each branch i
+    aᵏ⁺¹_parallel = [layer(aᵏ) for layer in parallel.layers]  # aᵏ⁺¹ᵢ for each branch i
+    c = Rᵏ⁺¹ ./ stabilize_denom(sum(aᵏ⁺¹_parallel))
+    Rᵏ⁺¹_parallel = [c .* a for a in aᵏ⁺¹_parallel]           # Rᵏ⁺¹ᵢ for each branch i
+    Rᵏ_parallel = [similar(aᵏ) for _ in parallel.layers]      # pre-allocate output Rᵏᵢ for each branch
 
-    for (Rᵏᵢ, ruleᵢ, layerᵢ, modified_layerᵢ, Rᵏ⁺¹ᵢ) in
-        zip(Rᵏs, rules, parallel.layers, modified_parallel, Rᵏ⁺¹s)
-        # In-place update Rᵏᵢ (and therefore Rᵏs)
-        lrp!(Rᵏᵢ, ruleᵢ, layerᵢ, modified_layerᵢ, aᵏ, Rᵏ⁺¹ᵢ)
+    for (Rᵏᵢ, rule, layer, modified_layer, Rᵏ⁺¹ᵢ) in
+        zip(Rᵏ_parallel, rules, parallel.layers, modified_parallel, Rᵏ⁺¹_parallel)
+        # In-place update Rᵏᵢ and therefore Rᵏ_parallel
+        lrp!(Rᵏᵢ, rule, layer, modified_layer, aᵏ, Rᵏ⁺¹ᵢ)
     end
-    return Rᵏ .= sum(Rᵏs)
+    return Rᵏ .= sum(Rᵏ_parallel)
 end
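A numeric check of the branch split used above: since each branch i receives `c .* aᵏ⁺¹ᵢ` with `c = Rᵏ⁺¹ ./ aᵏ⁺¹`, the branch relevances sum back to Rᵏ⁺¹ up to the stabilization term. `stabilize` below is a hypothetical stand-in for `stabilize_denom`:

```julia
# Hypothetical stand-in for stabilize_denom: nudge denominators away from zero.
stabilize(d) = d + 1e-9 * (d >= 0 ? 1 : -1)

a₁ = [1.0, 3.0]; a₂ = [2.0, -1.0]  # branch outputs aᵏ⁺¹ᵢ
Rᵏ⁺¹ = [0.5, 1.0]                  # relevance arriving at the Parallel output

c  = Rᵏ⁺¹ ./ stabilize.(a₁ .+ a₂)
R₁ = c .* a₁
R₂ = c .* a₂

@assert R₁ .+ R₂ ≈ Rᵏ⁺¹            # conservation across the fan-in
```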