Skip to content

Commit a6e2c59

Browse files
committed
Add FlatRule and WSquareRule
1 parent e9f7d35 commit a6e2c59

File tree

3 files changed

+48
-2
lines changed

3 files changed

+48
-2
lines changed

src/ExplainableAI.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ export LRP
4242
# LRP rules
4343
export AbstractLRPRule
4444
export LRP_CONFIG
45-
export ZeroRule, EpsilonRule, GammaRule, ZBoxRule, PassRule
45+
export ZeroRule, EpsilonRule, GammaRule, WSquareRule, FlatRule, ZBoxRule, PassRule
4646
export modify_input, modify_denominator
4747
export modify_param!, modify_layer!
4848
export check_model

src/lrp_rules.jl

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,29 @@ function modify_param!(r::GammaRule, param::AbstractArray{T}) where {T}
156156
return nothing
157157
end
158158

159+
"""
160+
WSquareRule()
161+
162+
LRP-``W^2`` rule. Commonly used on the first layer when values are unbounded.
163+
164+
# References
165+
[1]: G. Montavon et al., Explaining nonlinear classification decisions with deep Taylor decomposition
166+
"""
167+
struct WSquareRule <: AbstractLRPRule end
168+
modify_param!(::WSquareRule, p) = p .^= 2
169+
modify_input(::WSquareRule, input) = ones_like(input)
170+
171+
"""
172+
FlatRule()
173+
174+
LRP-Flat rule. Similar to the [`WSquareRule`](@ref), but with all parameters set to one.
175+
176+
# References
177+
[1]: S. Lapuschkin et al., Unmasking Clever Hans predictors and assessing what machines really learn
178+
"""
179+
struct FlatRule <: AbstractLRPRule end
180+
modify_param!(::FlatRule, p) = fill!(p, 0)
181+
modify_input(::FlatRule, input) = ones_like(input)
159182

160183
"""
161184
PassRule()
@@ -238,7 +261,7 @@ for R in (ZeroRule, EpsilonRule)
238261
end
239262

240263
# Fast implementation for Dense layer using Tullio.jl's einsum notation:
241-
for R in (ZeroRule, EpsilonRule, GammaRule)
264+
for R in (ZeroRule, EpsilonRule, GammaRule, WSquareRule, FlatRule)
242265
@eval function lrp!(Rₖ, rule::$R, layer::Dense, aₖ, Rₖ₊₁)
243266
reset! = get_layer_resetter(rule, layer)
244267
modify_layer!(rule, layer)

src/utils.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,29 @@ CartesianIndex(5, 3)
5252
"""
5353
drop_batch_index(C::CartesianIndex) = CartesianIndex(C.I[1:(end - 1)])
5454

55+
"""
56+
ones_like(x)
57+
58+
Returns array of ones of same shape and type as `x`.
59+
60+
## Example
61+
```julia-repl
62+
julia> x = rand(Float16, 2, 4, 1)
63+
2×4×1 Array{Float16, 3}:
64+
[:, :, 1] =
65+
0.2148 0.9053 0.751 0.358
66+
0.38 0.09033 0.04053 0.6543
67+
68+
julia> ones_like(x)
69+
2×4×1 Array{Float16, 3}:
70+
[:, :, 1] =
71+
1.0 1.0 1.0 1.0
72+
1.0 1.0 1.0 1.0
73+
```
74+
"""
75+
ones_like(x::AbstractArray) = ones(eltype(x), size(x))
76+
ones_like(x::Number) = oneunit(x)
77+
5578
# Utils for printing model check summary using PrettyTable.jl
5679
_print_name(layer) = "$layer"
5780
_print_name(layer::Parallel) = "Parallel(...)"

0 commit comments

Comments
 (0)