1
1
using ExplainabilityMethods
2
2
using ExplainabilityMethods: modify_params
3
+ import ExplainabilityMethods: _modify_layer
3
4
using Flux
4
5
using LinearAlgebra
5
6
using ReferenceTests
6
7
using Random
7
8
8
9
const RULES = Dict (
9
- " ZeroRule" => ZeroRule,
10
- " EpsilonRule" => EpsilonRule,
11
- " GammaRule" => GammaRule,
12
- " ZBoxRule" => ZBoxRule,
10
+ " ZeroRule" => ZeroRule () ,
11
+ " EpsilonRule" => EpsilonRule () ,
12
+ " GammaRule" => GammaRule () ,
13
+ " ZBoxRule" => ZBoxRule () ,
13
14
)
14
15
15
16
# # Hand-written tests
54
55
55
56
# # Test Dense layer
56
57
# Define Dense test input
57
- ins = 20 # input dimension
58
- outs = 10 # output dimension
59
- aₖ = pseudorandn (ins )
58
+ ins_dense = 20 # input dimension
59
+ outs_dense = 10 # output dimension
60
+ aₖ = pseudorandn (ins_dense )
60
61
61
62
layers = Dict (
62
- " Dense_relu" => Dense (ins, outs , relu; init= pseudorandn),
63
- " Dense_identity" => Dense (Matrix (I, outs, ins ), false , identity),
63
+ " Dense_relu" => Dense (ins_dense, outs_dense , relu; init= pseudorandn),
64
+ " Dense_identity" => Dense (Matrix (I, outs_dense, ins_dense ), false , identity),
64
65
)
65
66
@testset " Dense" begin
66
- for (rulename, ruletype) in RULES
67
- rule = ruletype ()
67
+ for (rulename, rule) in RULES
68
68
@testset " $rulename " begin
69
69
for (layername, layer) in layers
70
70
@testset " $layername " begin
@@ -76,10 +76,10 @@ layers = Dict(
76
76
77
77
# println(Rₖ)
78
78
if rulename == " Dense_identity"
79
- # First `outs ` dimensions should propagate
79
+ # First `outs_dense ` dimensions should propagate
80
80
# activations as relevances, rest should be ≈ 0.
81
- @test Rₖ[1 : outs ] ≈ aₖ[1 : outs ]
82
- @test all (Rₖ[outs : end ] .< 1e-8 )
81
+ @test Rₖ[1 : outs_dense ] ≈ aₖ[1 : outs_dense ]
82
+ @test all (Rₖ[outs_dense : end ] .< 1e-8 )
83
83
end
84
84
85
85
@test_reference " references/rules/$rulename /$layername .jld2" Dict (
@@ -103,8 +103,7 @@ equalpairs = Dict( # these pairs of layers are all equal
103
103
)
104
104
105
105
@testset " PoolingLayers" begin
106
- for (rulename, ruletype) in RULES
107
- rule = ruletype ()
106
+ for (rulename, rule) in RULES
108
107
@testset " $rulename " begin
109
108
for (layername, layers) in equalpairs
110
109
@testset " $layername " begin
@@ -139,8 +138,7 @@ layers = Dict(
139
138
" AlphaDropout" => AlphaDropout (0.2 ),
140
139
)
141
140
@testset " Other Layers" begin
142
- for (rulename, ruletype) in RULES
143
- rule = ruletype ()
141
+ for (rulename, rule) in RULES
144
142
@testset " $rulename " begin
145
143
for (layername, layer) in layers
146
144
@testset " $layername " begin
@@ -164,26 +162,33 @@ end
164
162
struct TestWrapper{T}
165
163
layer:: T
166
164
end
167
- (l:: TestWrapper )(x) = l. layer (x)
165
+ (w:: TestWrapper )(x) = w. layer (x)
166
+ _modify_layer (r:: AbstractLRPRule , w:: TestWrapper ) = _modify_layer (r, w. layer)
167
+ (rule:: ZBoxRule )(w:: TestWrapper , aₖ, Rₖ₊₁) = rule (w. layer, aₖ, Rₖ₊₁)
168
168
169
169
layers = Dict (
170
170
" Conv" => (Conv ((3 , 3 ), 2 => 4 ; init= pseudorandn), aₖ),
171
+ " Dense_relu" =>
172
+ (Dense (ins_dense, outs_dense, relu; init= pseudorandn), pseudorandn (ins_dense)),
171
173
" flatten" => (flatten, aₖ),
172
- " Dense" => (Dense (20 , 10 , relu; init= pseudorandn), pseudorandn (20 )),
173
174
)
174
175
@testset " Custom layers" begin
175
- for (layername, (layer, aₖ)) in layers
176
- @testset " $layername " begin
177
- rule = ZeroRule ()
178
- wrapped_layer = TestWrapper (layer)
179
- Rₖ₊₁ = wrapped_layer (aₖ)
180
- Rₖ = rule (wrapped_layer, aₖ, Rₖ₊₁)
181
-
182
- @test typeof (Rₖ) == typeof (aₖ)
183
- @test size (Rₖ) == size (aₖ)
184
-
185
- @test_reference " references/rules/ZeroRule/$layername .jld2" Dict (" R" => Rₖ) by =
186
- (r, a) -> isapprox (r[" R" ], a[" R" ]; rtol= 0.02 )
176
+ for (rulename, rule) in RULES
177
+ @testset " $rulename " begin
178
+ for (layername, (layer, aₖ)) in layers
179
+ @testset " $layername " begin
180
+ wrapped_layer = TestWrapper (layer)
181
+ Rₖ₊₁ = wrapped_layer (aₖ)
182
+ Rₖ = rule (wrapped_layer, aₖ, Rₖ₊₁)
183
+
184
+ @test typeof (Rₖ) == typeof (aₖ)
185
+ @test size (Rₖ) == size (aₖ)
186
+
187
+ @test_reference " references/rules/$rulename /$layername .jld2" Dict (
188
+ " R" => Rₖ
189
+ ) by = (r, a) -> isapprox (r[" R" ], a[" R" ]; rtol= 0.02 )
190
+ end
191
+ end
187
192
end
188
193
end
189
194
end
0 commit comments