Skip to content

Commit 4ec4908

Browse files
authored
Update heatmapping normalizer (#57)
* Update heatmapping normalizer to use ColorSchemes 3.18 * Update and rename heatmap tests * Fix for Julia 1.0 compat
1 parent 8641dfb commit 4ec4908

16 files changed

+22
-38
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ Tullio = "bc48ee85-29a4-5162-ae0b-a64e1601d4bc"
1717
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
1818

1919
[compat]
20-
ColorSchemes = "3"
20+
ColorSchemes = "3.18"
2121
Distributions = "0.25"
2222
Flux = "0.12, 0.13"
2323
ImageCore = "0.8, 0.9"

docs/literate/example.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,9 +130,9 @@ heatmap(input, analyzer)
130130
using ColorSchemes
131131
heatmap(expl; cs=ColorSchemes.jet)
132132
#
133-
heatmap(expl; reduce=:sum, normalize=:extrema, cs=ColorSchemes.inferno)
133+
heatmap(expl; reduce=:sum, rangescale=:extrema, cs=ColorSchemes.inferno)
134134

135135
# This also works with batches
136-
mosaic(heatmap(expl_batch; normalize=:extrema, cs=ColorSchemes.inferno); nrow=10)
136+
mosaic(heatmap(expl_batch; rangescale=:extrema, cs=ColorSchemes.inferno); nrow=10)
137137

138138
# For the full list of keyword arguments, refer to the [`heatmap`](@ref) documentation.

src/heatmap.jl

Lines changed: 10 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# NOTE: Heatmapping assumes Flux's WHCN convention (width, height, color channels, batch size).
22

33
const HEATMAPPING_PRESETS = Dict{Symbol,Tuple{ColorScheme,Symbol,Symbol}}(
4-
# Analyzer => (colorscheme, reduce, normalize)
4+
# Analyzer => (colorscheme, reduce, rangescale)
55
:LRP => (ColorSchemes.bwr, :sum, :centered), # attribution
66
:InputTimesGradient => (ColorSchemes.bwr, :sum, :centered), # attribution
77
:Gradient => (ColorSchemes.grays, :norm, :extrema), # gradient
@@ -29,7 +29,7 @@ Assumes Flux's WHCN convention (width, height, color channels, batch size).
2929
- `:maxabs`: compute `maximum(abs, x)` over the color channels in
3030
When calling `heatmap` with an `Explanation` or analyzer, the method default is selected.
3131
When calling `heatmap` with an array, the default is `:sum`.
32-
- `normalize::Symbol`: How the color channel reduced heatmap is normalized before the colorscheme is applied.
32+
- `rangescale::Symbol`: How the color channel reduced heatmap is normalized before the colorscheme is applied.
3333
Can be either `:extrema` or `:centered`.
3434
When calling `heatmap` with an `Explanation` or analyzer, the method default is selected.
3535
When calling `heatmap` with an array, the default for use with the `bwr` colorscheme is `:centered`.
@@ -43,7 +43,7 @@ function heatmap(
4343
attr::AbstractArray{T,N};
4444
cs::ColorScheme=ColorSchemes.bwr,
4545
reduce::Symbol=:sum,
46-
normalize::Symbol=:centered,
46+
rangescale::Symbol=:centered,
4747
permute::Bool=true,
4848
unpack_singleton::Bool=true,
4949
) where {T,N}
@@ -55,18 +55,18 @@ function heatmap(
5555
),
5656
)
5757
if unpack_singleton && size(attr, 4) == 1
58-
return _heatmap(attr[:, :, :, 1], cs, reduce, normalize, permute)
58+
return _heatmap(attr[:, :, :, 1], cs, reduce, rangescale, permute)
5959
end
60-
return map(a -> _heatmap(a, cs, reduce, normalize, permute), eachslice(attr; dims=4))
60+
return map(a -> _heatmap(a, cs, reduce, rangescale, permute), eachslice(attr; dims=4))
6161
end
6262

6363
# Use HEATMAPPING_PRESETS for default kwargs when dispatching on Explanation
6464
function heatmap(expl::Explanation; permute::Bool=true, kwargs...)
65-
_cs, _reduce, _normalize = HEATMAPPING_PRESETS[expl.analyzer]
65+
_cs, _reduce, _rangescale = HEATMAPPING_PRESETS[expl.analyzer]
6666
return heatmap(
6767
expl.attribution;
6868
reduce=get(kwargs, :reduce, _reduce),
69-
normalize=get(kwargs, :normalize, _normalize),
69+
rangescale=get(kwargs, :rangescale, _rangescale),
7070
cs=get(kwargs, :cs, _cs),
7171
permute=permute,
7272
)
@@ -81,28 +81,12 @@ function _heatmap(
8181
attr::AbstractArray{T,3},
8282
cs::ColorScheme,
8383
reduce::Symbol,
84-
normalize::Symbol,
84+
rangescale::Symbol,
8585
permute::Bool,
8686
) where {T<:Real}
87-
img = _normalize(dropdims(_reduce(attr, reduce); dims=3), normalize)
87+
img = dropdims(_reduce(attr, reduce); dims=3)
8888
permute && (img = permutedims(img))
89-
return ColorSchemes.get(cs, img)
90-
end
91-
92-
# Normalize activations across pixels
93-
function _normalize(attr, method::Symbol)
94-
if method == :centered
95-
min, max = (-1, 1) .* maximum(abs, attr)
96-
elseif method == :extrema
97-
min, max = extrema(attr)
98-
else
99-
throw(
100-
ArgumentError(
101-
"Color scheme normalizer :$method not supported, `normalize` should be :extrema or :centered",
102-
),
103-
)
104-
end
105-
return (attr .- min) / (max - min)
89+
return ColorSchemes.get(cs, img, rangescale)
10690
end
10791

10892
# Reduce attributions across color channels into a single scalar – assumes WHCN convention
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.

0 commit comments

Comments
 (0)