Skip to content

Commit 0e99f12

Browse files
authored
Fix RangeNormalizer (#17)
1 parent faffa98 commit 0e99f12

File tree

3 files changed

+22
-9
lines changed

3 files changed

+22
-9
lines changed

src/heatmap.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ function (::MaxAbsNormalizer)(img::SingleChannelImage)
2828
end
2929

3030
struct RangeNormalizer <: AbstractActivationNormalizer end
31-
function normalize(img::SingleChannelImage)
31+
function (::RangeNormalizer)(img::SingleChannelImage)
3232
min, max = extrema(img)
3333
return (img .- min) / (max - min)
3434
end

test/runtests.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,12 @@ using Test
88
@testset "Neuron selection" begin
99
include("test_neuron_selection.jl")
1010
end
11-
@testset "LRP rules" begin
12-
include("test_rules.jl")
13-
end
1411
@testset "Heatmaps" begin
1512
include("test_heatmaps.jl")
1613
end
14+
@testset "LRP rules" begin
15+
include("test_rules.jl")
16+
end
1717
@testset "VGG-19" begin
1818
include("test_vgg19.jl")
1919
end

test/test_heatmaps.jl

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,34 @@
11
using ExplainabilityMethods
2-
using ExplainabilityMethods: SumReducer
2+
using ExplainabilityMethods: MaxAbsNormalizer, RangeNormalizer
3+
using ExplainabilityMethods: MaxAbsReducer, SumReducer
34

45
# Defaults assume nchannels=3
56
A = rand(Float32, 3, 4, 5)
67
B = reshape(A, 3, 1, 4, 5, 1)
7-
@test heatmap(A) heatmap(B)
8+
@test heatmap(A; normalizer=MaxAbsNormalizer(), reducer=MaxAbsReducer())
9+
heatmap(B; normalizer=MaxAbsNormalizer(), reducer=MaxAbsReducer())
10+
@test heatmap(A; normalizer=RangeNormalizer(), reducer=SumReducer())
11+
heatmap(B; normalizer=RangeNormalizer(), reducer=SumReducer())
812

913
A = rand(Float32, 3, 4, 5)
1014
B = reshape(A, 3, 1, 4, 1, 5)
11-
@test heatmap(A; reducer=SumReducer()) heatmap(B; reducer=SumReducer())
15+
@test heatmap(A; normalizer=MaxAbsNormalizer(), reducer=MaxAbsReducer())
16+
heatmap(B; normalizer=MaxAbsNormalizer(), reducer=MaxAbsReducer())
17+
@test heatmap(A; normalizer=RangeNormalizer(), reducer=SumReducer())
18+
heatmap(B; normalizer=RangeNormalizer(), reducer=SumReducer())
1219

1320
# Test with single channel
1421
A = rand(Float32, 4, 5)
1522
B = reshape(A, 4, 1, 1, 5, 1)
16-
@test heatmap(A; nchannels=1) heatmap(B; nchannels=1)
23+
@test heatmap(A; nchannels=1, normalizer=MaxAbsNormalizer(), reducer=MaxAbsReducer())
24+
heatmap(B; nchannels=1, normalizer=MaxAbsNormalizer(), reducer=MaxAbsReducer())
25+
@test heatmap(A; nchannels=1, normalizer=RangeNormalizer(), reducer=SumReducer())
26+
heatmap(B; nchannels=1, normalizer=RangeNormalizer(), reducer=SumReducer())
1727

1828
# Test with 2 color channels
1929
A = rand(Float32, 2, 3, 4)
2030
B = reshape(A, 2, 1, 1, 3, 4)
21-
@test heatmap(A; nchannels=2) heatmap(B; nchannels=2)
31+
@test heatmap(A; nchannels=2, normalizer=MaxAbsNormalizer(), reducer=MaxAbsReducer())
32+
heatmap(B; nchannels=2, normalizer=MaxAbsNormalizer(), reducer=MaxAbsReducer())
33+
@test heatmap(A; nchannels=2, normalizer=RangeNormalizer(), reducer=SumReducer())
34+
heatmap(B; nchannels=2, normalizer=RangeNormalizer(), reducer=SumReducer())

0 commit comments

Comments
 (0)