Skip to content

Commit 1d0fbf6

Browse files
authored
Fix bug in ZBoxRule (#77)
* Refactor masking mechanism * Add kwarg `ignore_bias` to `modify_layer!` * Add `ignore_bias` tests * Refactor `modify_layer!` * Fix bug in `ZBoxRule` and update references
1 parent e6061e5 commit 1d0fbf6

15 files changed

+95
-52
lines changed

src/lrp_rules.jl

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -47,21 +47,28 @@ otherwise throw an `ArgumentError`.
4747
check_compat(rule, layer) = require_weight_and_bias(rule, layer)
4848

4949
"""
50-
modify_layer!(rule, layer)
50+
modify_layer!(rule, layer; ignore_bias=false)
5151
5252
In-place modify layer parameters by calling `modify_param!` before computing relevance
5353
propagation.
5454
5555
## Note
5656
When implementing a custom `modify_layer!` function, `modify_param!` will not be called.
5757
"""
58-
function modify_layer!(rule::R, layer::L) where {R,L}
59-
if has_weight_and_bias(layer)
60-
modify_param!(rule, layer.weight)
61-
modify_bias!(rule, layer.bias)
62-
end
58+
function modify_layer!(rule::R, layer::L; ignore_bias=false) where {R,L}
59+
!has_weight_and_bias(layer) && return nothing # skip all
60+
modify_weight!(rule, layer.weight)
61+
62+
# Checks that skip bias modification:
63+
ignore_bias && return nothing
64+
isa(layer.bias, Flux.Zeros) && return nothing # skip if bias=Flux.Zeros (Flux <= v0.12)
65+
isa(layer.bias, Bool) && !layer.bias && return nothing # skip if bias=false (Flux >= v0.13)
66+
67+
modify_bias!(rule, layer.bias)
6368
return nothing
6469
end
70+
modify_weight!(rule::R, W) where {R} = modify_param!(rule, W)
71+
modify_bias!(rule::R, b) where {R} = modify_param!(rule, b)
6572

6673
"""
6774
modify_param!(rule, W)
@@ -71,17 +78,15 @@ Inplace-modify parameters before computing the relevance.
7178
"""
7279
modify_param!(rule, param) = nothing # general fallback
7380

74-
# Useful presets:
75-
modify_param!(::Val{:mask_positive}, p) = p .= max.(zero(eltype(p)), p)
76-
modify_param!(::Val{:mask_negative}, p) = p .= min.(zero(eltype(p)), p)
81+
# Useful presets that allow us to work around bias-free layers:
82+
modify_param!(::Val{:keep_positive}, p) = keep_positive!(p)
83+
modify_param!(::Val{:keep_negative}, p) = keep_negative!(p)
7784

78-
# Internal wrapper functions for bias-free layers.
79-
modify_bias!(rule::R, b) where {R} = modify_param!(rule, b)
80-
modify_bias!(rule, b::Flux.Zeros) = nothing # skip if bias=Flux.Zeros (Flux <= v0.12)
81-
function modify_bias!(rule, b::Bool) # skip if bias=false (Flux >= v0.13)
82-
@assert b == false
83-
return nothing
84-
end
85+
modify_weight!(::Val{:keep_positive_zero_bias}, W) = keep_positive!(W)
86+
modify_bias!(::Val{:keep_positive_zero_bias}, b) = fill!(b, zero(eltype(b)))
87+
88+
modify_weight!(::Val{:keep_negative_zero_bias}, W) = keep_negative!(W)
89+
modify_bias!(::Val{:keep_negative_zero_bias}, b) = fill!(b, zero(eltype(b)))
8590

8691
# Internal function that resets parameters by capturing them in a closure.
8792
# Returns a function `reset!` that resets the parameters to their original state when called.
@@ -226,12 +231,12 @@ function lrp!(Rₖ, rule::ZBoxRule, layer::L, aₖ, Rₖ₊₁) where {L}
226231
aₖ₊₁, pullback = Zygote.pullback(layer, aₖ)
227232

228233
# Compute pullback for W⁺, b⁺
229-
modify_layer!(Val{:mask_positive}, layer)
234+
modify_layer!(Val(:keep_positive), layer)
230235
aₖ₊₁⁺, pullback⁺ = Zygote.pullback(layer, l)
231236
reset!()
232237

233238
# Compute pullback for W⁻, b⁻
234-
modify_layer!(Val{:mask_negative}, layer)
239+
modify_layer!(Val(:keep_negative), layer)
235240
aₖ₊₁⁻, pullback⁻ = Zygote.pullback(layer, h)
236241
reset!()
237242

@@ -246,12 +251,7 @@ function zbox_input(in::AbstractArray{T}, A::AbstractArray) where {T}
246251
return convert.(T, A)
247252
end
248253

249-
# Special cases for rules that don't modify params for extra performance:
250-
for R in (ZeroRule, EpsilonRule)
251-
@eval get_layer_resetter(::$R, l) = Returns(nothing)
252-
@eval lrp!(Rₖ, ::$R, ::DropoutLayer, aₖ, Rₖ₊₁) = (Rₖ .= Rₖ₊₁)
253-
@eval lrp!(Rₖ, ::$R, ::ReshapingLayer, aₖ, Rₖ₊₁) = (Rₖ .= reshape(Rₖ₊₁, size(aₖ)))
254-
end
254+
255255

256256
# Special cases for rules that don't modify params for extra performance:
257257
for R in (ZeroRule, EpsilonRule)

src/utils.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,17 @@ julia> ones_like(x)
7575
ones_like(x::AbstractArray) = ones(eltype(x), size(x))
7676
ones_like(x::Number) = oneunit(x)
7777

78+
function keep_positive!(x::AbstractArray{T}) where {T}
79+
x[x .< 0] .= zero(T)
80+
return x
81+
end
82+
function keep_negative!(x::AbstractArray{T}) where {T}
83+
x[x .> 0] .= zero(T)
84+
return x
85+
end
86+
keep_positive(x) = keep_positive!(deepcopy(x))
87+
keep_negative(x) = keep_negative!(deepcopy(x))
88+
7889
# Utils for printing model check summary using PrettyTable.jl
7990
_print_name(layer) = "$layer"
8091
_print_name(layer::Parallel) = "Parallel(...)"

test/references/heatmaps/vgg11_LRPCustom.txt

Lines changed: 15 additions & 15 deletions
Large diffs are not rendered by default.
-1.34 KB
Binary file not shown.
0 Bytes
Binary file not shown.
0 Bytes
Binary file not shown.
0 Bytes
Binary file not shown.
0 Bytes
Binary file not shown.
0 Bytes
Binary file not shown.
-1.34 KB
Binary file not shown.

0 commit comments

Comments
 (0)