Skip to content

Commit f120801

Browse files
authored
Refactor heatmapping using symbols (#32)
1 parent 93413f2 commit f120801

File tree

6 files changed

+76
-68
lines changed

6 files changed

+76
-68
lines changed

src/heatmap.jl

Lines changed: 53 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,59 +1,73 @@
1-
const SingleChannelImage = AbstractArray{<:Real,2}
2-
abstract type AbstractActivationNormalizer end
3-
abstract type AbstractColorReducer end
1+
# NOTE: Heatmapping assumes Flux's WHCN convention (width, height, color channels, batch size).
42

53
"""
64
heatmap(expl; kwargs...)
75
86
Visualize explanation.
7+
Assumes the Flux's WHCN convention (width, height, color channels, batch size).
8+
9+
## Keyword arguments
10+
-`cs::ColorScheme`: ColorScheme that is applied. Defaults to `ColorSchemes.bwr`.
11+
-`reduce::Symbol`: How the color channels are reduced to a single number to apply a colorscheme.
12+
Can be either `:sum` or `:maxabs`. `:sum` sums up all color channels for each pixel.
13+
`:maxabs` selects the `maximum(abs, x)` over the color channel in each pixel.
14+
Default is `:sum`.
15+
-`normalize::Symbol`: How the color channel reduced heatmap is normalized before the colorscheme is applied.
16+
Can be either `:extrema` or `:centered`. Default for use with colorscheme `bwr` is `:centered`.
917
"""
1018
function heatmap(
1119
expl::AbstractArray;
1220
cs::ColorScheme=ColorSchemes.bwr,
13-
normalizer::AbstractActivationNormalizer=MaxAbsNormalizer(),
14-
reducer::AbstractColorReducer=SumReducer(),
15-
nchannels::Int=3,
21+
reduce::Symbol=:sum,
22+
normalize::Symbol=:centered,
1623
permute::Bool=true,
1724
)
18-
img = normalizer(reducer(drop_singleton_dims(expl), nchannels))
25+
_size = size(expl)
26+
length(_size) != 4 && throw(
27+
DomainError(
28+
_size,
29+
"""heatmap assumes Flux's WHCN convention (width, height, color channels, batch size) for the input.
30+
Please reshape your attribution to match this format if your model doesn't adhere to this convention.""",
31+
),
32+
)
33+
_size[end] != 1 && throw(
34+
DomainError(
35+
_size[end],
36+
"""heatmap is only applicable to a single attribution, got a batch dimension of $(_size[end]).""",
37+
),
38+
)
39+
# drop batch dim -> reduce color channels -> normalize image -> apply color scheme
40+
img = _normalize(_reduce(dropdims(expl; dims=4), reduce), normalize)
1941
permute && (img = permutedims(img))
20-
return get(cs, img)
42+
return ColorSchemes.get(cs, img)
2143
end
2244

2345
# Normalize activations across pixels
24-
struct MaxAbsNormalizer <: AbstractActivationNormalizer end
25-
function (::MaxAbsNormalizer)(img::SingleChannelImage)
26-
absmax = maximum(abs, img)
27-
return img / (2 * absmax) .+ 0.5
28-
end
29-
30-
struct RangeNormalizer <: AbstractActivationNormalizer end
31-
function (::RangeNormalizer)(img::SingleChannelImage)
32-
min, max = extrema(img)
33-
return (img .- min) / (max - min)
46+
function _normalize(attr, method::Symbol)
47+
if method == :centered
48+
min, max = (-1, 1) .* maximum(abs, attr)
49+
elseif method == :extrema
50+
min, max = extrema(attr)
51+
else
52+
throw(
53+
ArgumentError(
54+
"Color scheme normalizer :$method not supported, `normalize` should be :extrema or :centered",
55+
),
56+
)
57+
end
58+
return (attr .- min) / (max - min)
3459
end
3560

3661
# Reduces activation in a pixel with multiple color channels into a single activation
37-
struct MaxAbsReducer <: AbstractColorReducer end
38-
function (::MaxAbsReducer)(img, nchannels)
39-
nchannels == 1 && return img
40-
dim = find_color_channel(img, nchannels)
41-
return dropdims(maximum(abs, img; dims=dim); dims=dim)
42-
end
43-
44-
struct SumReducer <: AbstractColorReducer end
45-
function (::SumReducer)(img, nchannels)
46-
nchannels == 1 && return img
47-
dim = find_color_channel(img, nchannels)
48-
return dropdims(sum(img; dims=dim); dims=dim)
49-
end
50-
51-
function find_color_channel(img, nchannels)
52-
colordims = findall(size(img) .== nchannels)
53-
if length(colordims) == 0
54-
throw(ArgumentError("No dimension with nchannels=$nchannels color channels found."))
55-
elseif length(colordims) > 1
56-
throw(ArgumentError("Several dimensions of length $nchannels found."))
62+
function _reduce(attr, method::Symbol)
63+
if method == :maxabs
64+
return dropdims(maximum(abs, attr; dims=3); dims=3)
65+
elseif method == :sum
66+
return dropdims(sum(attr; dims=3); dims=3)
5767
end
58-
return first(colordims)
68+
throw(
69+
ArgumentError(
70+
"Color channel reducer :$method not supported, `reduce` should be :maxabs or :sum",
71+
),
72+
)
5973
end
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
▀▀
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
▀▀
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
▀▀
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
▀▀

test/test_heatmaps.jl

Lines changed: 19 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,24 @@
11
using ExplainabilityMethods
2-
using ExplainabilityMethods: MaxAbsNormalizer, RangeNormalizer
3-
using ExplainabilityMethods: MaxAbsReducer, SumReducer
42

5-
# Defaults assume nchannels=3
6-
A = rand(Float32, 3, 4, 5)
7-
B = reshape(A, 3, 1, 4, 5, 1)
8-
@test heatmap(A; normalizer=MaxAbsNormalizer(), reducer=MaxAbsReducer())
9-
heatmap(B; normalizer=MaxAbsNormalizer(), reducer=MaxAbsReducer())
10-
@test heatmap(A; normalizer=RangeNormalizer(), reducer=SumReducer())
11-
heatmap(B; normalizer=RangeNormalizer(), reducer=SumReducer())
3+
# NOTE: Heatmapping assumes Flux's WHCN convention (width, height, color channels, batch size).
4+
shape = (2, 2, 3, 1)
5+
A = reshape(collect(Float32, 1:prod(shape)), shape)
126

13-
A = rand(Float32, 3, 4, 5)
14-
B = reshape(A, 3, 1, 4, 1, 5)
15-
@test heatmap(A; normalizer=MaxAbsNormalizer(), reducer=MaxAbsReducer())
16-
heatmap(B; normalizer=MaxAbsNormalizer(), reducer=MaxAbsReducer())
17-
@test heatmap(A; normalizer=RangeNormalizer(), reducer=SumReducer())
18-
heatmap(B; normalizer=RangeNormalizer(), reducer=SumReducer())
7+
reducers = [:sum, :maxabs]
8+
normalizers = [:extrema, :centered]
9+
for r in reducers
10+
for n in normalizers
11+
h = @inferred heatmap(A; reduce=r, normalize=n)
12+
@test_reference "references/heatmaps/reduce_$(r)_normalize_$(n).txt" h
13+
end
14+
end
1915

20-
# Test with single channel
21-
A = rand(Float32, 4, 5)
22-
B = reshape(A, 4, 1, 1, 5, 1)
23-
@test heatmap(A; nchannels=1, normalizer=MaxAbsNormalizer(), reducer=MaxAbsReducer())
24-
heatmap(B; nchannels=1, normalizer=MaxAbsNormalizer(), reducer=MaxAbsReducer())
25-
@test heatmap(A; nchannels=1, normalizer=RangeNormalizer(), reducer=SumReducer())
26-
heatmap(B; nchannels=1, normalizer=RangeNormalizer(), reducer=SumReducer())
16+
@test_throws ArgumentError heatmap(A, reduce=:foo)
17+
@test_throws ArgumentError heatmap(A, normalize=:bar)
2718

28-
# Test with 2 color channels
29-
A = rand(Float32, 2, 3, 4)
30-
B = reshape(A, 2, 1, 1, 3, 4)
31-
@test heatmap(A; nchannels=2, normalizer=MaxAbsNormalizer(), reducer=MaxAbsReducer())
32-
heatmap(B; nchannels=2, normalizer=MaxAbsNormalizer(), reducer=MaxAbsReducer())
33-
@test heatmap(A; nchannels=2, normalizer=RangeNormalizer(), reducer=SumReducer())
34-
heatmap(B; nchannels=2, normalizer=RangeNormalizer(), reducer=SumReducer())
19+
B = reshape(A, 2, 2, 3, 1, 1)
20+
@test_throws DomainError heatmap(B)
21+
B = reshape(A, 2, 2, 3)
22+
@test_throws DomainError heatmap(B)
23+
B = reshape(A, 2, 2, 1, 3)
24+
@test_throws DomainError heatmap(B)

0 commit comments

Comments
 (0)