
Commit e4092d4

Remove modify_layer (#29)

* Remove `modify_layer`
* Drop uncommon Conv layers as they currently don't apply `modify_params`
* Rename `get`/`set_weights` to `get`/`set_params`

Parent: 62c908d

5 files changed: 29 additions, 52 deletions

5 files changed

+29
-52
lines changed

docs/literate/example.jl (2 additions, 3 deletions)

```diff
@@ -89,9 +89,8 @@ heatmap(expl)
 # The rule has to be of type `AbstractLRPRule`.
 struct MyCustomLRPRule <: AbstractLRPRule end
 
-# It is then possible to dispatch on the utility functions [`modify_layer`](@ref),
-# [`modify_params`](@ref) and [`modify_denominator`](@ref) with our rule type
-# `MyCustomLRPRule` to define custom rules without writing boilerplate code.
+# It is then possible to dispatch on the utility functions [`modify_params`](@ref) and [`modify_denominator`](@ref)
+# with our rule type `MyCustomLRPRule` to define custom rules without writing boilerplate code.
 function modify_params(::MyCustomLRPRule, W, b)
     ρW = W + 0.1 * relu.(W)
     return ρW, b
```
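For orientation, a minimal usage sketch of the custom rule defined in this example under the new API. The `model` and the exact `LRP` constructor call are assumptions for illustration, not taken from this commit:

```julia
using Flux
using ExplainabilityMethods
import ExplainabilityMethods: modify_params  # required to extend the function

struct MyCustomLRPRule <: AbstractLRPRule end

# Dispatch only on the parameter hook; the generic autodiff rule does the rest.
modify_params(::MyCustomLRPRule, W, b) = (W + 0.1 * relu.(W), b)

model = Chain(Dense(784, 100, relu), Dense(100, 10))  # hypothetical model
analyzer = LRP(model, MyCustomLRPRule())              # assumed constructor signature
```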

docs/src/api.md (0 additions, 1 deletion)

````diff
@@ -25,7 +25,6 @@ ZBoxRule
 ## Custom rules
 These utilities can be used to define custom rules without writing boilerplate code:
 ```@docs
-modify_layer
 modify_params
 modify_denominator
 ```
````

src/ExplainabilityMethods.jl (1 addition, 1 deletion)

```diff
@@ -31,7 +31,7 @@ export LRP, LRPZero, LRPEpsilon, LRPGamma
 export AbstractLRPRule
 export LRP_CONFIG
 export ZeroRule, EpsilonRule, GammaRule, ZBoxRule
-export modify_layer, modify_params, modify_denominator
+export modify_params, modify_denominator
 export check_model
 
 # heatmapping
```
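A brief migration sketch for downstream code (`MyRule` is hypothetical): the removed `modify_layer` override point is replaced by overriding `modify_params` on the weights and biases directly.

```julia
import ExplainabilityMethods: modify_params

struct MyRule <: AbstractLRPRule end  # hypothetical downstream rule

# Before this commit, a rule could also override `modify_layer(rule, layer)` to
# rebuild the layer itself; that hook is now internal. Only the parameter and
# denominator hooks remain public:
modify_params(::MyRule, W, b) = (abs.(W), b)  # illustrative transformation
```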

src/flux.jl (5 additions, 5 deletions)

```diff
@@ -1,5 +1,5 @@
 ## Group layers by type:
-const ConvLayer = Union{Conv,DepthwiseConv,ConvTranspose,CrossCor}
+const ConvLayer = Union{Conv} # TODO: DepthwiseConv, ConvTranspose, CrossCor
 const DropoutLayer = Union{Dropout,typeof(Flux.dropout),AlphaDropout}
 const ReshapingLayer = Union{typeof(Flux.flatten)}
 # Pooling layers
@@ -62,7 +62,7 @@ function strip_softmax(model::Chain)
 end
 
 # helper function to work around Flux.Zeros
-function get_weights(layer)
+function get_params(layer)
     W = layer.weight
     b = layer.bias
     if typeof(b) <: Flux.Zeros
@@ -72,9 +72,9 @@ function get_weights(layer)
 end
 
 """
-    set_weights(layer, W, b)
+    set_params(layer, W, b)
 
 Duplicate layer using weights W, b.
 """
-set_weights(l::Conv, W, b) = Conv(l.σ, W, b, l.stride, l.pad, l.dilation, l.groups)
-set_weights(l::Dense, W, b) = Dense(W, b, l.σ)
+set_params(l::Conv, W, b) = Conv(l.σ, W, b, l.stride, l.pad, l.dilation, l.groups)
+set_params(l::Dense, W, b) = Dense(W, b, l.σ)
```
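A minimal sketch of the renamed helpers round-tripping a `Dense` layer; the signatures come from this diff, the concrete layer is made up:

```julia
using Flux

layer = Dense(3, 2, relu)
W, b = get_params(layer)               # materializes the bias, working around Flux.Zeros
scaled = set_params(layer, 2 .* W, b)  # duplicate layer with modified weights
@assert scaled.σ === layer.σ           # activation function is carried over
```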

src/lrp_rules.jl (21 additions, 42 deletions)

```diff
@@ -18,50 +18,44 @@
 abstract type AbstractLRPRule end
 
 # This is the generic relevance propagation rule which is used for the 0, γ and ϵ rules.
-# It can be extended for new rules via `modify_denominator` and `modify_layer`,
-# which in turn uses `modify_params`.
+# It can be extended for new rules via `modify_denominator` and `modify_params`.
+# Since it uses autodiff, it is used as a fallback for layer types without custom implementation.
 function (rule::AbstractLRPRule)(layer, aₖ, Rₖ₊₁)
-    layerᵨ = modify_layer(rule, layer)
+    layerᵨ = _modify_layer(rule, layer)
     function fwpass(a)
         z = layerᵨ(a)
         s = Zygote.dropgrad(Rₖ₊₁ ./ modify_denominator(rule, z))
         return z ⋅ s
     end
-    c = gradient(fwpass, aₖ)[1]
-    Rₖ = aₖ .* c
-    return Rₖ
+    return aₖ .* gradient(fwpass, aₖ)[1] # Rₖ
 end
 
 # Special cases are dispatched on layer type:
-(rule::AbstractLRPRule)(::DropoutLayer, aₖ, Rₖ₊₁) = Rₖ₊₁
-(rule::AbstractLRPRule)(::ReshapingLayer, aₖ, Rₖ₊₁) = reshape(Rₖ₊₁, size(aₖ))
+(::AbstractLRPRule)(::DropoutLayer, aₖ, Rₖ₊₁) = Rₖ₊₁
+(::AbstractLRPRule)(::ReshapingLayer, aₖ, Rₖ₊₁) = reshape(Rₖ₊₁, size(aₖ))
 
+# To implement new rules, we can define two custom functions `modify_params` and `modify_denominator`.
+# If this isn't done, the following fallbacks are used by default:
 """
-    modify_layer(rule, layer)
-
-Applies `modify_params` to layer if it has parameters
-"""
-modify_layer(::AbstractLRPRule, l) = l # skip layers without params
-function modify_layer(rule::AbstractLRPRule, l::Union{Dense,Conv})
-    W, b = get_weights(l)
-    ρW, ρb = modify_params(rule, W, b)
-    return set_weights(l, ρW, ρb)
-end
-
-"""
-    modify_params!(rule, W, b)
+    modify_params(rule, W, b)
 
 Function that modifies weights and biases before applying relevance propagation.
 """
 modify_params(::AbstractLRPRule, W, b) = (W, b) # general fallback
 
 """
-    modify_denominator!(d, rule)
+    modify_denominator(rule, d)
 
 Function that modifies zₖ on the forward pass, e.g. for numerical stability.
 """
 modify_denominator(::AbstractLRPRule, d) = stabilize_denom(d; eps=1.0f-9) # general fallback
 
+# This helper function applies `modify_params`:
+_modify_layer(::AbstractLRPRule, layer) = layer # skip layers without modify_params
+function _modify_layer(rule::AbstractLRPRule, layer::Union{Dense,Conv})
+    return set_params(layer, modify_params(rule, get_params(layer)...)...)
+end
+
 """
     ZeroRule()
 
@@ -111,11 +105,11 @@ struct ZBoxRule <: AbstractLRPRule end
 
 # The ZBoxRule requires its own implementation of relevance propagation.
 function (rule::ZBoxRule)(layer::Union{Dense,Conv}, aₖ, Rₖ₊₁)
-    layer, layer⁺, layer⁻ = modify_layer(rule, layer)
+    W, b = get_params(layer)
+    l, h = fill.(extrema(aₖ), (size(aₖ),))
 
-    onemat = ones(eltype(aₖ), size(aₖ))
-    l = onemat * minimum(aₖ)
-    h = onemat * maximum(aₖ)
+    layer⁺ = set_params(layer, max.(0, W), max.(0, b)) # W⁺, b⁺
+    layer⁻ = set_params(layer, min.(0, W), min.(0, b)) # W⁻, b⁻
 
     # Forward pass
     function fwpass(a, l, h)
@@ -128,20 +122,5 @@ function (rule::ZBoxRule)(layer::Union{Dense,Conv}, aₖ, Rₖ₊₁)
         return z ⋅ s
     end
     c, cₗ, cₕ = gradient(fwpass, aₖ, l, h) # w.r.t. three inputs
-
-    # Backward pass
-    Rₖ = aₖ .* c + l .* cₗ + h .* cₕ
-    return Rₖ
-end
-
-function modify_layer(::ZBoxRule, l::Union{Dense,Conv})
-    W, b = get_weights(l)
-    W⁻ = min.(0, W)
-    W⁺ = max.(0, W)
-    b⁻ = min.(0, b)
-    b⁺ = max.(0, b)
-
-    l⁺ = set_weights(l, W⁺, b⁺)
-    l⁻ = set_weights(l, W⁻, b⁻)
-    return l, l⁺, l⁻
+    return aₖ .* c + l .* cₗ + h .* cₕ # Rₖ from backward pass
 end
```
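A small worked example of the new bound construction in the `ZBoxRule` (values are illustrative): `extrema` returns the `(min, max)` tuple, and broadcasting `fill` over it against a one-element shape tuple yields one constant array per bound.

```julia
aₖ = [0.5 -1.0; 2.0 0.0]
l, h = fill.(extrema(aₖ), (size(aₖ),))
# extrema(aₖ) == (-1.0, 2.0), so l == fill(-1.0, (2, 2)) and h == fill(2.0, (2, 2)),
# replacing the previous `onemat * minimum(aₖ)` / `onemat * maximum(aₖ)` construction.
```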
