Skip to content

Commit 0c4c786

Browse files
authored
Fix tests and benchmarks on TestWrapper (#30)
1 parent e4092d4 commit 0c4c786

File tree

2 files changed

+57
-45
lines changed

2 files changed

+57
-45
lines changed

benchmark/benchmarks.jl

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using BenchmarkTools
22
using Flux
33
using ExplainabilityMethods
4+
import ExplainabilityMethods: _modify_layer
45

56
on_CI = haskey(ENV, "GITHUB_ACTIONS")
67

@@ -29,7 +30,9 @@ SUITE = BenchmarkGroup()
2930
SUITE["VGG"] = BenchmarkGroup([k for k in keys(algs)])
3031
for (name, alg) in algs
3132
SUITE["VGG"][name] = BenchmarkGroup(["construct analyzer", "analyze"])
32-
SUITE["VGG"][name]["construct analyzer"] = @benchmarkable contruct_analyzer($(alg), $(model))
33+
SUITE["VGG"][name]["construct analyzer"] = @benchmarkable contruct_analyzer(
34+
$(alg), $(model)
35+
)
3336

3437
analyzer = alg(model)
3538
SUITE["VGG"][name]["analyze"] = @benchmarkable analyze($(img), $(analyzer))
@@ -39,36 +42,40 @@ end
3942
struct TestWrapper{T}
4043
layer::T
4144
end
42-
(l::TestWrapper)(x) = l.layer(x)
45+
(w::TestWrapper)(x) = w.layer(x)
46+
_modify_layer(r::AbstractLRPRule, w::TestWrapper) = _modify_layer(r, w.layer)
47+
(rule::ZBoxRule)(w::TestWrapper, aₖ, Rₖ₊₁) = rule(w.layer, aₖ, Rₖ₊₁)
4348

4449
# generate input for conv layers
45-
insize = (128, 128, 3, 1)
50+
insize = (64, 64, 3, 1)
51+
in_dense = 500
52+
out_dense = 100
4653
aₖ = randn(Float32, insize)
4754

4855
layers = Dict(
4956
"MaxPool" => (MaxPool((3, 3); pad=0), aₖ),
50-
"MeanPool" => (MeanPool((3, 3); pad=0), aₖ),
51-
"Conv" => (Conv((3, 3), 3 => 6), aₖ),
52-
"flatten" => (flatten, aₖ),
53-
"Dense" => (Dense(1000, 200, relu), randn(Float32, 1000)),
57+
"Conv" => (Conv((3, 3), 3 => 2), aₖ),
58+
"Dense" => (Dense(in_dense, out_dense, relu), randn(Float32, in_dense)),
59+
"WrappedDense" =>
60+
(TestWrapper(Dense(in_dense, out_dense, relu)), randn(Float32, in_dense)),
5461
)
5562
rules = Dict(
5663
"ZeroRule" => ZeroRule(),
5764
"EpsilonRule" => EpsilonRule(),
5865
"GammaRule" => GammaRule(),
5966
"ZBoxRule" => ZBoxRule(),
6067
)
61-
rulenames = [k for k in keys(rules)]
6268

6369
test_rule(rule, layer, aₖ, Rₖ₊₁) = rule(layer, aₖ, Rₖ₊₁) # for use with @benchmarkable macro
6470

71+
SUITE["Layer"] = BenchmarkGroup([k for k in keys(layers)])
6572
for (layername, (layer, aₖ)) in layers
66-
SUITE[layername] = BenchmarkGroup(rulenames)
67-
Rₖ₊₁ = layer(aₖ)
73+
SUITE["Layer"][layername] = BenchmarkGroup([k for k in keys(rules)])
6874

75+
Rₖ₊₁ = layer(aₖ)
6976
for (rulename, rule) in rules
70-
SUITE[layername][rulename] = BenchmarkGroup(["dispatch", "AD fallback"])
71-
SUITE[layername][rulename]["dispatch"] = @benchmarkable test_rule($(rule), $(layer), $(aₖ), $(Rₖ₊₁))
72-
SUITE[layername][rulename]["AD fallback"] = @benchmarkable test_rule($(rule), $(TestWrapper(layer)), $(aₖ), $(Rₖ₊₁))
77+
SUITE["Layer"][layername][rulename] = @benchmarkable test_rule(
78+
$(rule), $(layer), $(aₖ), $(Rₖ₊₁)
79+
)
7380
end
7481
end

test/test_rules.jl

Lines changed: 37 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,16 @@
11
using ExplainabilityMethods
22
using ExplainabilityMethods: modify_params
3+
import ExplainabilityMethods: _modify_layer
34
using Flux
45
using LinearAlgebra
56
using ReferenceTests
67
using Random
78

89
const RULES = Dict(
9-
"ZeroRule" => ZeroRule,
10-
"EpsilonRule" => EpsilonRule,
11-
"GammaRule" => GammaRule,
12-
"ZBoxRule" => ZBoxRule,
10+
"ZeroRule" => ZeroRule(),
11+
"EpsilonRule" => EpsilonRule(),
12+
"GammaRule" => GammaRule(),
13+
"ZBoxRule" => ZBoxRule(),
1314
)
1415

1516
## Hand-written tests
@@ -54,17 +55,16 @@ end
5455

5556
## Test Dense layer
5657
# Define Dense test input
57-
ins = 20 # input dimension
58-
outs = 10 # output dimension
59-
aₖ = pseudorandn(ins)
58+
ins_dense = 20 # input dimension
59+
outs_dense = 10 # output dimension
60+
aₖ = pseudorandn(ins_dense)
6061

6162
layers = Dict(
62-
"Dense_relu" => Dense(ins, outs, relu; init=pseudorandn),
63-
"Dense_identity" => Dense(Matrix(I, outs, ins), false, identity),
63+
"Dense_relu" => Dense(ins_dense, outs_dense, relu; init=pseudorandn),
64+
"Dense_identity" => Dense(Matrix(I, outs_dense, ins_dense), false, identity),
6465
)
6566
@testset "Dense" begin
66-
for (rulename, ruletype) in RULES
67-
rule = ruletype()
67+
for (rulename, rule) in RULES
6868
@testset "$rulename" begin
6969
for (layername, layer) in layers
7070
@testset "$layername" begin
@@ -76,10 +76,10 @@ layers = Dict(
7676

7777
# println(Rₖ)
7878
if rulename == "Dense_identity"
79-
# First `outs` dimensions should propagate
79+
# First `outs_dense` dimensions should propagate
8080
# activations as relevances, rest should be ≈ 0.
81-
@test Rₖ[1:outs] aₖ[1:outs]
82-
@test all(Rₖ[outs:end] .< 1e-8)
81+
@test Rₖ[1:outs_dense] aₖ[1:outs_dense]
82+
@test all(Rₖ[outs_dense:end] .< 1e-8)
8383
end
8484

8585
@test_reference "references/rules/$rulename/$layername.jld2" Dict(
@@ -103,8 +103,7 @@ equalpairs = Dict( # these pairs of layers are all equal
103103
)
104104

105105
@testset "PoolingLayers" begin
106-
for (rulename, ruletype) in RULES
107-
rule = ruletype()
106+
for (rulename, rule) in RULES
108107
@testset "$rulename" begin
109108
for (layername, layers) in equalpairs
110109
@testset "$layername" begin
@@ -139,8 +138,7 @@ layers = Dict(
139138
"AlphaDropout" => AlphaDropout(0.2),
140139
)
141140
@testset "Other Layers" begin
142-
for (rulename, ruletype) in RULES
143-
rule = ruletype()
141+
for (rulename, rule) in RULES
144142
@testset "$rulename" begin
145143
for (layername, layer) in layers
146144
@testset "$layername" begin
@@ -164,26 +162,33 @@ end
164162
struct TestWrapper{T}
165163
layer::T
166164
end
167-
(l::TestWrapper)(x) = l.layer(x)
165+
(w::TestWrapper)(x) = w.layer(x)
166+
_modify_layer(r::AbstractLRPRule, w::TestWrapper) = _modify_layer(r, w.layer)
167+
(rule::ZBoxRule)(w::TestWrapper, aₖ, Rₖ₊₁) = rule(w.layer, aₖ, Rₖ₊₁)
168168

169169
layers = Dict(
170170
"Conv" => (Conv((3, 3), 2 => 4; init=pseudorandn), aₖ),
171+
"Dense_relu" =>
172+
(Dense(ins_dense, outs_dense, relu; init=pseudorandn), pseudorandn(ins_dense)),
171173
"flatten" => (flatten, aₖ),
172-
"Dense" => (Dense(20, 10, relu; init=pseudorandn), pseudorandn(20)),
173174
)
174175
@testset "Custom layers" begin
175-
for (layername, (layer, aₖ)) in layers
176-
@testset "$layername" begin
177-
rule = ZeroRule()
178-
wrapped_layer = TestWrapper(layer)
179-
Rₖ₊₁ = wrapped_layer(aₖ)
180-
Rₖ = rule(wrapped_layer, aₖ, Rₖ₊₁)
181-
182-
@test typeof(Rₖ) == typeof(aₖ)
183-
@test size(Rₖ) == size(aₖ)
184-
185-
@test_reference "references/rules/ZeroRule/$layername.jld2" Dict("R" => Rₖ) by =
186-
(r, a) -> isapprox(r["R"], a["R"]; rtol=0.02)
176+
for (rulename, rule) in RULES
177+
@testset "$rulename" begin
178+
for (layername, (layer, aₖ)) in layers
179+
@testset "$layername" begin
180+
wrapped_layer = TestWrapper(layer)
181+
Rₖ₊₁ = wrapped_layer(aₖ)
182+
Rₖ = rule(wrapped_layer, aₖ, Rₖ₊₁)
183+
184+
@test typeof(Rₖ) == typeof(aₖ)
185+
@test size(Rₖ) == size(aₖ)
186+
187+
@test_reference "references/rules/$rulename/$layername.jld2" Dict(
188+
"R" => Rₖ
189+
) by = (r, a) -> isapprox(r["R"], a["R"]; rtol=0.02)
190+
end
191+
end
187192
end
188193
end
189194
end

0 commit comments

Comments
 (0)