@@ -14,14 +14,12 @@ batchsize = 15
14
14
15
15
model = Chain (Dense (ins, 15 , relu; init= pseudorand), Dense (15 , outs, relu; init= pseudorand))
16
16
17
- # Input 1 w/o batch dimension
18
- input1_no_bd = rand (MersenneTwister (1 ), Float32, ins)
19
17
# Input 1 with batch dimension
20
- input1_bd = reshape (input1_no_bd , ins, 1 )
18
+ input1 = rand ( MersenneTwister ( 1 ), Float32 , ins, 1 )
21
19
# Input 2 with batch dimension
22
- input2_bd = rand (MersenneTwister (2 ), Float32, ins, 1 )
20
+ input2 = rand (MersenneTwister (2 ), Float32, ins, 1 )
23
21
# Batch containing inputs 1 & 2
24
- input_batch = cat (input1_bd, input2_bd ; dims= 2 )
22
+ input_batch = cat (input1, input2 ; dims= 2 )
25
23
26
24
ANALYZERS = Dict (
27
25
" Gradient" => Gradient,
@@ -33,25 +31,21 @@ ANALYZERS = Dict(
33
31
34
32
for (name, method) in ANALYZERS
35
33
@testset " $name " begin
36
- # Using `add_batch_dim=true` should result in same explanation
37
- # as input reshaped to have a batch dimension
38
34
analyzer = method (model)
39
- expl1_no_bd = analyzer (input1_no_bd; add_batch_dim= true )
40
- analyzer = method (model)
41
- expl1_bd = analyzer (input1_bd)
42
- @test expl1_bd. val ≈ expl1_no_bd. val
35
+ expl1 = analyzer (input1)
36
+ @test expl1. val ≈ expl1. val
43
37
44
38
# Analyzing a batch should have the same result
45
39
# as analyzing inputs in batch individually
46
40
analyzer = method (model)
47
- expl2_bd = analyzer (input2_bd )
41
+ expl2 = analyzer (input2 )
48
42
analyzer = method (model)
49
43
expl_batch = analyzer (input_batch)
50
- @test expl1_bd . val ≈ expl_batch. val[:, 1 ]
44
+ @test expl1 . val ≈ expl_batch. val[:, 1 ]
51
45
if ! (analyzer isa NoiseAugmentation)
52
46
# NoiseAugmentation methods generate random numbers for the entire batch.
53
47
# therefore explanations don't match except for the first input in the batch.
54
- @test expl2_bd . val ≈ expl_batch. val[:, 2 ]
48
+ @test expl2 . val ≈ expl_batch. val[:, 2 ]
55
49
end
56
50
end
57
51
end
0 commit comments