Skip to content

Commit dc4d0fc

Browse files
authored
Fix ZBoxRule (#69)
1 parent 8b37758 commit dc4d0fc

File tree

9 files changed

+43
-16
lines changed

9 files changed

+43
-16
lines changed

benchmark/benchmarks.jl

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,16 @@ on_CI = haskey(ENV, "GITHUB_ACTIONS")
99
include("../test/vgg11.jl")
1010
vgg11 = VGG11(; pretrain=false)
1111
model = flatten_model(strip_softmax(vgg11.layers))
12-
img = rand(MersenneTwister(123), Float32, (224, 224, 3, 1))
12+
13+
T = Float32
14+
img = rand(MersenneTwister(123), T, (224, 224, 3, 1))
1315

1416
# Benchmark custom LRP composite
1517
function LRPCustom(model::Chain)
16-
return LRP(model, [ZBoxRule(), repeat([GammaRule()], length(model.layers) - 1)...])
18+
return LRP(
19+
model,
20+
[ZBoxRule(zero(T), oneunit(T)), repeat([GammaRule()], length(model.layers) - 1)...],
21+
)
1722
end
1823

1924
# Use one representative algorithm of each type
@@ -53,20 +58,20 @@ lrp!(rule::ZBoxRule, w::TestWrapper, Rₖ, aₖ, Rₖ₊₁) = lrp!(rule, w.laye
5358
insize = (64, 64, 3, 1)
5459
in_dense = 500
5560
out_dense = 100
56-
aₖ = randn(Float32, insize)
61+
aₖ = randn(T, insize)
5762

5863
layers = Dict(
5964
"MaxPool" => (MaxPool((3, 3); pad=0), aₖ),
6065
"Conv" => (Conv((3, 3), 3 => 2), aₖ),
61-
"Dense" => (Dense(in_dense, out_dense, relu), randn(Float32, in_dense, 1)),
66+
"Dense" => (Dense(in_dense, out_dense, relu), randn(T, in_dense, 1)),
6267
"WrappedDense" =>
63-
(TestWrapper(Dense(in_dense, out_dense, relu)), randn(Float32, in_dense, 1)),
68+
(TestWrapper(Dense(in_dense, out_dense, relu)), randn(T, in_dense, 1)),
6469
)
6570
rules = Dict(
6671
"ZeroRule" => ZeroRule(),
6772
"EpsilonRule" => EpsilonRule(),
6873
"GammaRule" => GammaRule(),
69-
"ZBoxRule" => ZBoxRule(),
74+
"ZBoxRule" => ZBoxRule(zero(T), oneunit(T)),
7075
)
7176

7277
SUITE["Layer"] = BenchmarkGroup([k for k in keys(layers)])

docs/literate/advanced_lrp.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ input = reshape(x, 28, 28, 1, :);
2323
# we can also assign rules to each layer individually.
2424
# For this purpose, we create an array of rules that matches the length of the Flux chain:
2525
rules = [
26-
ZBoxRule(),
26+
ZBoxRule(0.0f0, 1.0f0),
2727
GammaRule(),
2828
GammaRule(),
2929
EpsilonRule(),

src/lrp_rules.jl

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -107,21 +107,41 @@ end
107107
modify_denominator(r::EpsilonRule, d) = stabilize_denom(d, r.ϵ)
108108

109109
"""
110-
ZBoxRule()
110+
ZBoxRule(low, high)
111111
112112
Constructor for LRP-``z^{\\mathcal{B}}``-rule.
113113
Commonly used on the first layer for pixel input.
114+
115+
The parameters `low` and `high` should be set to the lower and upper bounds of the input features,
116+
e.g. `0.0` and `1.0` for raw image data.
117+
It is also possible to provide two arrays of that match the input size.
114118
"""
115-
struct ZBoxRule <: AbstractLRPRule end
119+
struct ZBoxRule{T} <: AbstractLRPRule
120+
low::T
121+
high::T
122+
end
116123

117124
# The ZBoxRule requires its own implementation of relevance propagation.
118-
lrp!(::ZBoxRule, layer::Dense, Rₖ, aₖ, Rₖ₊₁) = lrp_zbox!(layer, Rₖ, aₖ, Rₖ₊₁)
119-
lrp!(::ZBoxRule, layer::Conv, Rₖ, aₖ, Rₖ₊₁) = lrp_zbox!(layer, Rₖ, aₖ, Rₖ₊₁)
125+
lrp!(r::ZBoxRule, layer::Dense, Rₖ, aₖ, Rₖ₊₁) = lrp_zbox!(r, layer, Rₖ, aₖ, Rₖ₊₁)
126+
lrp!(r::ZBoxRule, layer::Conv, Rₖ, aₖ, Rₖ₊₁) = lrp_zbox!(r, layer, Rₖ, aₖ, Rₖ₊₁)
127+
128+
_zbox_bound(T, c::Real, in_size) = fill(convert(T, c), in_size)
129+
function _zbox_bound(T, A::AbstractArray, in_size)
130+
size(A) != in_size && throw(
131+
ArgumentError(
132+
"Bounds `low`, `high` of ZBoxRule should either be scalar or match input size.",
133+
),
134+
)
135+
return convert.(T, A)
136+
end
120137

121-
function lrp_zbox!(layer::L, Rₖ::T1, aₖ::T1, Rₖ₊₁::T2) where {L,T1,T2}
122-
W, b = get_params(layer)
123-
l, h = fill.(extrema(aₖ), (size(aₖ),))
138+
function lrp_zbox!(r::ZBoxRule, layer::L, Rₖ::T1, aₖ::T1, Rₖ₊₁::T2) where {L,T1,T2}
139+
T = eltype(aₖ)
140+
in_size = size(aₖ)
141+
l = _zbox_bound(T, r.low, in_size)
142+
h = _zbox_bound(T, r.high, in_size)
124143

144+
W, b = get_params(layer)
125145
layer⁺ = set_params(layer, max.(0, W), max.(0, b)) # W⁺, b⁺
126146
layer⁻ = set_params(layer, min.(0, W), min.(0, b)) # W⁻, b⁻
127147

0 Bytes
Binary file not shown.
0 Bytes
Binary file not shown.
0 Bytes
Binary file not shown.
0 Bytes
Binary file not shown.

test/test_rules.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ const RULES = Dict(
1111
"ZeroRule" => ZeroRule(),
1212
"EpsilonRule" => EpsilonRule(),
1313
"GammaRule" => GammaRule(),
14-
"ZBoxRule" => ZBoxRule(),
14+
"ZBoxRule" => ZBoxRule(0.0f0, 1.0f0),
1515
)
1616

1717
## Hand-written tests

test/test_vgg11.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,9 @@ vgg11 = VGG11(; pretrain=false)
2626
model = flatten_model(strip_softmax(vgg11.layers))
2727

2828
function LRPCustom(model::Chain)
29-
return LRP(model, [ZBoxRule(), repeat([GammaRule()], length(model.layers) - 1)...])
29+
return LRP(
30+
model, [ZBoxRule(0.0f0, 1.0f0), repeat([GammaRule()], length(model.layers) - 1)...]
31+
)
3032
end
3133

3234
function test_vgg11(name, method; kwargs...)

0 commit comments

Comments
 (0)