|
1 |
| -const SingleChannelImage = AbstractArray{<:Real,2} |
2 |
| -abstract type AbstractActivationNormalizer end |
3 |
| -abstract type AbstractColorReducer end |
| 1 | +# NOTE: Heatmapping assumes Flux's WHCN convention (width, height, color channels, batch size). |
4 | 2 |
|
5 | 3 | """
|
6 | 4 | heatmap(expl; kwargs...)
|
7 | 5 |
|
8 | 6 | Visualize explanation.
|
| 7 | +Assumes the Flux's WHCN convention (width, height, color channels, batch size). |
| 8 | +
|
| 9 | +## Keyword arguments |
| 10 | +-`cs::ColorScheme`: ColorScheme that is applied. Defaults to `ColorSchemes.bwr`. |
| 11 | +-`reduce::Symbol`: How the color channels are reduced to a single number to apply a colorscheme. |
| 12 | + Can be either `:sum` or `:maxabs`. `:sum` sums up all color channels for each pixel. |
| 13 | + `:maxabs` selects the `maximum(abs, x)` over the color channel in each pixel. |
| 14 | + Default is `:sum`. |
| 15 | +-`normalize::Symbol`: How the color channel reduced heatmap is normalized before the colorscheme is applied. |
| 16 | + Can be either `:extrema` or `:centered`. Default for use with colorscheme `bwr` is `:centered`. |
9 | 17 | """
|
10 | 18 | function heatmap(
|
11 | 19 | expl::AbstractArray;
|
12 | 20 | cs::ColorScheme=ColorSchemes.bwr,
|
13 |
| - normalizer::AbstractActivationNormalizer=MaxAbsNormalizer(), |
14 |
| - reducer::AbstractColorReducer=SumReducer(), |
15 |
| - nchannels::Int=3, |
| 21 | + reduce::Symbol=:sum, |
| 22 | + normalize::Symbol=:centered, |
16 | 23 | permute::Bool=true,
|
17 | 24 | )
|
18 |
| - img = normalizer(reducer(drop_singleton_dims(expl), nchannels)) |
| 25 | + _size = size(expl) |
| 26 | + length(_size) != 4 && throw( |
| 27 | + DomainError( |
| 28 | + _size, |
| 29 | + """heatmap assumes Flux's WHCN convention (width, height, color channels, batch size) for the input. |
| 30 | + Please reshape your attribution to match this format if your model doesn't adhere to this convention.""", |
| 31 | + ), |
| 32 | + ) |
| 33 | + _size[end] != 1 && throw( |
| 34 | + DomainError( |
| 35 | + _size[end], |
| 36 | + """heatmap is only applicable to a single attribution, got a batch dimension of $(_size[end]).""", |
| 37 | + ), |
| 38 | + ) |
| 39 | + # drop batch dim -> reduce color channels -> normalize image -> apply color scheme |
| 40 | + img = _normalize(_reduce(dropdims(expl; dims=4), reduce), normalize) |
19 | 41 | permute && (img = permutedims(img))
|
20 |
| - return get(cs, img) |
| 42 | + return ColorSchemes.get(cs, img) |
21 | 43 | end
|
22 | 44 |
|
23 | 45 | # Normalize activations across pixels
|
24 |
| -struct MaxAbsNormalizer <: AbstractActivationNormalizer end |
25 |
| -function (::MaxAbsNormalizer)(img::SingleChannelImage) |
26 |
| - absmax = maximum(abs, img) |
27 |
| - return img / (2 * absmax) .+ 0.5 |
28 |
| -end |
29 |
| - |
30 |
| -struct RangeNormalizer <: AbstractActivationNormalizer end |
31 |
| -function (::RangeNormalizer)(img::SingleChannelImage) |
32 |
| - min, max = extrema(img) |
33 |
| - return (img .- min) / (max - min) |
| 46 | +function _normalize(attr, method::Symbol) |
| 47 | + if method == :centered |
| 48 | + min, max = (-1, 1) .* maximum(abs, attr) |
| 49 | + elseif method == :extrema |
| 50 | + min, max = extrema(attr) |
| 51 | + else |
| 52 | + throw( |
| 53 | + ArgumentError( |
| 54 | + "Color scheme normalizer :$method not supported, `normalize` should be :extrema or :centered", |
| 55 | + ), |
| 56 | + ) |
| 57 | + end |
| 58 | + return (attr .- min) / (max - min) |
34 | 59 | end
|
35 | 60 |
|
36 | 61 | # Reduces activation in a pixel with multiple color channels into a single activation
|
37 |
| -struct MaxAbsReducer <: AbstractColorReducer end |
38 |
| -function (::MaxAbsReducer)(img, nchannels) |
39 |
| - nchannels == 1 && return img |
40 |
| - dim = find_color_channel(img, nchannels) |
41 |
| - return dropdims(maximum(abs, img; dims=dim); dims=dim) |
42 |
| -end |
43 |
| - |
44 |
| -struct SumReducer <: AbstractColorReducer end |
45 |
| -function (::SumReducer)(img, nchannels) |
46 |
| - nchannels == 1 && return img |
47 |
| - dim = find_color_channel(img, nchannels) |
48 |
| - return dropdims(sum(img; dims=dim); dims=dim) |
49 |
| -end |
50 |
| - |
51 |
| -function find_color_channel(img, nchannels) |
52 |
| - colordims = findall(size(img) .== nchannels) |
53 |
| - if length(colordims) == 0 |
54 |
| - throw(ArgumentError("No dimension with nchannels=$nchannels color channels found.")) |
55 |
| - elseif length(colordims) > 1 |
56 |
| - throw(ArgumentError("Several dimensions of length $nchannels found.")) |
| 62 | +function _reduce(attr, method::Symbol) |
| 63 | + if method == :maxabs |
| 64 | + return dropdims(maximum(abs, attr; dims=3); dims=3) |
| 65 | + elseif method == :sum |
| 66 | + return dropdims(sum(attr; dims=3); dims=3) |
57 | 67 | end
|
58 |
| - return first(colordims) |
| 68 | + throw( |
| 69 | + ArgumentError( |
| 70 | + "Color channel reducer :$method not supported, `reduce` should be :maxabs or :sum", |
| 71 | + ), |
| 72 | + ) |
59 | 73 | end
|
0 commit comments