
Commit 3aa10ba

Fix neuron selection on GPUs (#140)
* Fix neuron selection
* Rename `rels` to `Rs`
* Rename `acts` to `as`
1 parent b6658cb · commit 3aa10ba

3 files changed: 34 additions & 39 deletions


docs/src/lrp/developer.md

Lines changed: 7 additions & 7 deletions
@@ -148,21 +148,21 @@ For a detailed description of the layer modification mechanism, refer to the sec
 
 ## Forward and reverse pass
 When calling an `LRP` analyzer, a forward pass through the model is performed,
-saving the activations $aᵏ$ for all layers $k$ in a vector called `acts`.
+saving the activations $aᵏ$ for all layers $k$ in a vector called `as`.
 This vector of activations is then used to pre-allocate the relevances $R^k$
-for all layers in a vector called `rels`.
+for all layers in a vector called `Rs`.
 This is possible since for any layer $k$, $a^k$ and $R^k$ have the same shape.
-Finally, the last array of relevances $R^N$ in `rels` is set to zeros,
+Finally, the last array of relevances $R^N$ in `Rs` is set to zeros,
 except for the specified output neuron, which is set to one.
 
 We can now run the reverse pass, iterating backwards over the layers in the model
-and writing relevances $R^k$ into the pre-allocated array `rels`:
+and writing relevances $R^k$ into the pre-allocated array `Rs`:
 
 ```julia
 for k in length(model):-1:1
     # └─ loop over layers in reverse
-    lrp!(rels[k], rules[k], layers[k], modified_layers[i], acts[k], rels[k+1])
-    #    └─ Rᵏ: modified in-place                          └─ aᵏ     └─ Rᵏ⁺¹
+    lrp!(Rs[k], rules[k], layers[k], modified_layers[k], as[k], Rs[k+1])
+    #    └─ Rᵏ: modified in-place                        └─ aᵏ   └─ Rᵏ⁺¹
 end
 ```
 
@@ -185,7 +185,7 @@ and the output relevance `Rᵏ⁺¹`.
 The exclamation point in the function name `lrp!` is a
 [naming convention](https://docs.julialang.org/en/v1/manual/style-guide/#bang-convention)
 in Julia to denote functions that modify their arguments --
-in this case the first argument `rels[k]`, which corresponds to $R^k$.
+in this case the first argument `Rs[k]`, which corresponds to $R^k$.
 
 ### Rule calls
 As discussed in [*The AD fallback*](@ref lrp-dev-ad-fallback),

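To make the documented forward/reverse pass concrete, here is a toy, self-contained sketch. It only mirrors the structure described above: the `model` is a plain vector of Julia functions and `toy_lrp!` redistributes relevance uniformly, a hypothetical stand-in for the package's rules and `lrp!` methods.

```julia
# Toy sketch of the documented forward/reverse pass (not the package's actual rules).
model = [x -> 2 .* x, x -> x .+ 1, x -> [sum(x)]]   # three stand-in "layers"
input = [1.0, 2.0, 3.0]

# Forward pass: store the input and every layer output, i.e. aᵏ for all layers k
as = Any[input]
for layer in model
    push!(as, layer(as[end]))
end

# Pre-allocate relevances Rᵏ with the same shapes as the activations
Rs = [similar(a) for a in as]

# Mask the output relevance Rᴺ: zero everywhere except the selected output neuron
fill!(Rs[end], 0)
Rs[end][1] = 1

# Reverse pass: stand-in rule that spreads Rᵏ⁺¹ uniformly over the entries of aᵏ
toy_lrp!(Rᵏ, aᵏ, Rᵏ⁺¹) = (Rᵏ .= sum(Rᵏ⁺¹) / length(aᵏ); Rᵏ)
for k in length(model):-1:1
    toy_lrp!(Rs[k], as[k], Rs[k + 1])
end

Rs[1]   # relevance attributed to the input
```

In the package itself, `lrp!` dispatches on the rule and layer types and writes the relevances into `Rs` in place, exactly as in the loop shown in the diff above.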
src/lrp/lrp.jl

Lines changed: 22 additions & 22 deletions
@@ -48,40 +48,40 @@ LRP(model::Chain, c::Composite; kwargs...) = LRP(model, lrp_rules(model, c); kwa
 
 get_activations(model, input) = [input, Flux.activations(model, input)...]
 
+function mask_output_neuron!(Rᴺ, aᴺ, ns::AbstractNeuronSelector)
+    fill!(Rᴺ, 0)
+    idx = ns(aᴺ)
+    Rᴺ[idx] .= 1
+    return Rᴺ
+end
+
 # Call to the LRP analyzer
 function (lrp::LRP)(
     input::AbstractArray{T}, ns::AbstractNeuronSelector; layerwise_relevances=false
 ) where {T}
-    acts = get_activations(lrp.model, input)      # compute aᵏ for all layers k
-    rels = similar.(acts)                         # allocate Rᵏ for all layers k
-    mask_output_neuron!(rels[end], acts[end], ns) # compute Rᵏ⁺¹ of output layer
-
-    # Apply LRP rules in backward-pass, inplace-updating relevances `rels[i]`
-    for i in length(lrp.model):-1:1
-        lrp!(
-            rels[i],
-            lrp.rules[i],
-            lrp.model[i],
-            lrp.modified_layers[i],
-            acts[i],
-            rels[i + 1],
-        )
+    as = get_activations(lrp.model, input)     # compute activations aᵏ for all layers k
+    Rs = similar.(as)                          # allocate relevances Rᵏ for all layers k
+    mask_output_neuron!(Rs[end], as[end], ns)  # compute relevance Rᴺ of output layer N
+
+    # Apply LRP rules in backward-pass, inplace-updating relevances `Rs[k]` = Rᵏ
+    for k in length(lrp.model):-1:1
+        lrp!(Rs[k], lrp.rules[k], lrp.model[k], lrp.modified_layers[k], as[k], Rs[k + 1])
     end
-    extras = layerwise_relevances ? (layerwise_relevances=rels,) : nothing
 
-    return Explanation(first(rels), last(acts), ns(last(acts)), :LRP, extras)
+    extras = layerwise_relevances ? (layerwise_relevances=Rs,) : nothing
+    return Explanation(first(Rs), last(as), ns(last(as)), :LRP, extras)
 end
 
 function lrp!(Rᵏ, rules::ChainTuple, chain::Chain, modified_chain::ChainTuple, aᵏ, Rᵏ⁺¹)
-    acts = get_activations(chain, aᵏ)
-    rels = similar.(acts)
-    last(rels) .= Rᵏ⁺¹
+    as = get_activations(chain, aᵏ)
+    Rs = similar.(as)
+    last(Rs) .= Rᵏ⁺¹
 
-    # Apply LRP rules in backward-pass, inplace-updating relevances `rels[i]`
+    # Apply LRP rules in backward-pass, inplace-updating relevances `Rs[i]`
     for i in length(chain):-1:1
-        lrp!(rels[i], rules[i], chain[i], modified_chain[i], acts[i], rels[i + 1])
+        lrp!(Rs[i], rules[i], chain[i], modified_chain[i], as[i], Rs[i + 1])
     end
-    return Rᵏ .= first(rels)
+    return Rᵏ .= first(Rs)
 end
 
 function lrp!(

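The `mask_output_neuron!` helper moved into this file can be illustrated in isolation. The sketch below assumes a 2D output (neurons × batch, batch last) and uses `max_selector`, a hypothetical stand-in for `MaxActivationSelector`, to show how one output neuron per batch column is seeded with relevance 1.

```julia
# Standalone sketch of the masking performed by mask_output_neuron! above.
# max_selector is a hypothetical stand-in for MaxActivationSelector.
max_selector(out) = vec(argmax(out; dims=1:(ndims(out) - 1)))

aᴺ = [0.1  2.0;    # output activations: 3 neurons × 2 batch samples
      0.7  0.2;
      0.3  0.5]
Rᴺ = similar(aᴺ)

fill!(Rᴺ, 0)              # zero all output relevances ...
idx = max_selector(aᴺ)    # ... pick one CartesianIndex per batch column ...
Rᴺ[idx] .= 1              # ... and set the selected neurons to one

# Rᴺ == [0.0 1.0; 1.0 0.0; 0.0 0.0]
```

Because `idx` comes from `argmax(...; dims=...)` and is applied with a broadcasted assignment, the same pattern presumably stays on the device when `aᴺ` and `Rᴺ` are GPU arrays, which is what the commit title refers to.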
src/neuron_selection.jl

Lines changed: 5 additions & 10 deletions
@@ -1,12 +1,5 @@
 abstract type AbstractNeuronSelector end
 
-function mask_output_neuron!(R, a, ns::AbstractNeuronSelector)
-    fill!(R, 0)
-    idx = ns(a)
-    R[idx] .= 1
-    return R
-end
-
 """
     MaxActivationSelector()
 
@@ -15,7 +8,7 @@ Neuron selector that picks the output neuron with the highest activation.
 struct MaxActivationSelector <: AbstractNeuronSelector end
 function (::MaxActivationSelector)(out::AbstractArray{T,N}) where {T,N}
     N < 2 && throw(BATCHDIM_MISSING)
-    return Vector{CartesianIndex{N}}([argmax(out; dims=1:(N - 1))...])
+    return vec(argmax(out; dims=1:(N - 1)))
 end
 
 """
@@ -28,11 +21,13 @@ struct IndexSelector{I} <: AbstractNeuronSelector
 end
 function (s::IndexSelector{<:Integer})(out::AbstractArray{T,N}) where {T,N}
     N < 2 && throw(BATCHDIM_MISSING)
-    return CartesianIndex{N}.(s.index, 1:size(out, N))
+    batchsize = size(out, N)
+    return [CartesianIndex{N}(s.index, b) for b in 1:batchsize]
 end
 function (s::IndexSelector{I})(out::AbstractArray{T,N}) where {I,T,N}
     N < 2 && throw(BATCHDIM_MISSING)
-    return CartesianIndex{N}.(s.index..., 1:size(out, N))
+    batchsize = size(out, N)
+    return [CartesianIndex{N}(s.index..., b) for b in 1:batchsize]
 end
 
 """

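For reference, a short usage sketch of the updated selectors, assuming the types above are in scope and that `IndexSelector(3)` constructs a selector for output neuron 3 (batch dimension last):

```julia
out = [0.1  2.0;   # 3 output neurons × 2 batch samples
       0.7  0.2;
       0.3  0.5]

MaxActivationSelector()(out)  # [CartesianIndex(2, 1), CartesianIndex(1, 2)]
IndexSelector(3)(out)         # [CartesianIndex(3, 1), CartesianIndex(3, 2)]
```

Both selectors now return a plain vector of `CartesianIndex` values, one per batch sample. For `MaxActivationSelector`, `vec(argmax(...))` merely reshapes the index array returned by `argmax` instead of splatting it element by element into a `Vector` constructor, which is presumably the part that previously failed on GPU arrays.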