@@ -6,6 +6,7 @@ const GRADIENT_ANALYZERS = Dict(
6
6
" InputTimesGradient" => InputTimesGradient,
7
7
" SmoothGrad" => m -> SmoothGrad (m, 5 , 0.1 , MersenneTwister (123 )),
8
8
" IntegratedGradients" => m -> IntegratedGradients (m, 5 ),
9
+ " GradCAM" => m -> GradCAM (m[1 ], m[2 ]),
9
10
)
10
11
11
12
input_size = (32 , 32 , 3 , 1 )
@@ -67,17 +68,13 @@ function test_cnn(name, method)
67
68
println (" Timing $name ..." )
68
69
print (" cold:" )
69
70
@time expl = analyze (input, analyzer)
70
-
71
- @test size (expl. val) == size (input)
72
71
@test_reference " references/cnn/$(name) _max.jld2" Dict (" expl" => expl. val) by =
73
72
(r, a) -> isapprox (r[" expl" ], a[" expl" ]; rtol= 0.05 )
74
73
end
75
74
@testset " Neuron selection" begin
76
75
analyzer = method (model)
77
76
print (" warm:" )
78
77
@time expl = analyze (input, analyzer, 1 )
79
-
80
- @test size (expl. val) == size (input)
81
78
@test_reference " references/cnn/$(name) _ns1.jld2" Dict (" expl" => expl. val) by =
82
79
(r, a) -> isapprox (r[" expl" ], a[" expl" ]; rtol= 0.05 )
83
80
end
0 commit comments