1
1
# NOTE: Heatmapping assumes Flux's WHCN convention (width, height, color channels, batch size).
2
2
3
+ const HEATMAPPING_PRESETS = Dict {Symbol,Tuple{ColorScheme,Symbol,Symbol}} (
4
+ # Analyzer => (colorscheme, reduce, normalize)
5
+ :LRP => (ColorSchemes. bwr, :sum , :centered ),
6
+ :InputTimesGradient => (ColorSchemes. bwr, :sum , :centered ), # same as LRP
7
+ :Gradient => (ColorSchemes. grays, :norm , :extrema ),
8
+ )
9
+
3
10
"""
4
- heatmap(expl; kwargs...)
11
+ heatmap(expl::Explanation; kwargs...)
12
+ heatmap(attr::AbstractArray; kwargs...)
13
+ heatmap(input, analyzer::AbstractXAIMethod)
14
+ heatmap(input, analyzer::AbstractXAIMethod, neuron_selection::Int)
5
15
6
16
Visualize explanation.
7
17
Assumes the Flux's WHCN convention (width, height, color channels, batch size).
8
18
9
19
## Keyword arguments
10
- -`cs::ColorScheme`: ColorScheme that is applied. Defaults to `ColorSchemes.bwr`.
20
+ -`cs::ColorScheme`: ColorScheme that is applied.
21
+ When calling `heatmap` with an `Explanation` or analyzer, the method default is selected.
22
+ When calling `heatmap` with an array, the default is `ColorSchemes.bwr`.
11
23
-`reduce::Symbol`: How the color channels are reduced to a single number to apply a colorscheme.
12
- Can be either `:sum` or `:maxabs`. `:sum` sums up all color channels for each pixel.
13
- `:maxabs` selects the `maximum(abs, x)` over the color channel in each pixel.
14
- Default is `:sum`.
24
+ The following methods can be selected, which are then applied over the color channels
25
+ for each "pixel" in the attribution:
26
+ - `:sum`: sum up color channels
27
+ - `:norm`: compute 2-norm over the color channels
28
+ - `:maxabs`: compute `maximum(abs, x)` over the color channels in
29
+ When calling `heatmap` with an `Explanation` or analyzer, the method default is selected.
30
+ When calling `heatmap` with an array, the default is `:sum`.
15
31
-`normalize::Symbol`: How the color channel reduced heatmap is normalized before the colorscheme is applied.
16
- Can be either `:extrema` or `:centered`. Default for use with colorscheme `bwr` is `:centered`.
32
+ Can be either `:extrema` or `:centered`.
33
+ When calling `heatmap` with an `Explanation` or analyzer, the method default is selected.
34
+ When calling `heatmap` with an array, the default for use with the `bwr` colorscheme is `:centered`.
35
+ -`permute::Bool`: Whether to flip W&H input channels. Default is `true`.
36
+
37
+ **Note:** these keyword arguments can't be used when calling `heatmap` with an analyzer.
17
38
"""
39
+
18
40
function heatmap (
19
- expl :: AbstractArray ;
41
+ attr :: AbstractArray ;
20
42
cs:: ColorScheme = ColorSchemes. bwr,
21
43
reduce:: Symbol = :sum ,
22
44
normalize:: Symbol = :centered ,
23
45
permute:: Bool = true ,
24
46
)
25
- _size = size (expl )
47
+ _size = size (attr )
26
48
length (_size) != 4 && throw (
27
49
DomainError (
28
50
_size,
@@ -36,11 +58,27 @@ function heatmap(
36
58
""" heatmap is only applicable to a single attribution, got a batch dimension of $(_size[end ]) .""" ,
37
59
),
38
60
)
39
- # drop batch dim -> reduce color channels -> normalize image -> apply color scheme
40
- img = _normalize (_reduce (dropdims (expl ; dims= 4 ), reduce), normalize)
61
+
62
+ img = _normalize (dropdims ( _reduce (dropdims (attr ; dims= 4 ), reduce); dims = 3 ), normalize)
41
63
permute && (img = permutedims (img))
42
64
return ColorSchemes. get (cs, img)
43
65
end
66
+ # Use HEATMAPPING_PRESETS for default kwargs when dispatching on Explanation
67
+ function heatmap (expl:: Explanation ; permute:: Bool = true , kwargs... )
68
+ _cs, _reduce, _normalize = HEATMAPPING_PRESETS[expl. analyzer]
69
+ return heatmap (
70
+ expl. attribution;
71
+ reduce= get (kwargs, :reduce , _reduce),
72
+ normalize= get (kwargs, :normalize , _normalize),
73
+ cs= get (kwargs, :cs , _cs),
74
+ permute= permute,
75
+ )
76
+ end
77
+ # Analyze & heatmap in one go
78
+ function heatmap (input, analyzer:: AbstractXAIMethod , args... ; kwargs... )
79
+ return heatmap (analyze (input, analyzer, args... ; kwargs... ))
80
+ end
81
+
44
82
45
83
# Normalize activations across pixels
46
84
function _normalize (attr, method:: Symbol )
@@ -58,16 +96,20 @@ function _normalize(attr, method::Symbol)
58
96
return (attr .- min) / (max - min)
59
97
end
60
98
61
- # Reduces activation in a pixel with multiple color channels into a single activation
62
- function _reduce (attr, method:: Symbol )
63
- if method == :maxabs
64
- return dropdims (maximum (abs, attr; dims= 3 ); dims= 3 )
99
+ # Reduce attributions across color channels into a single scalar – assumes WHCN convention
100
+ function _reduce (attr:: T , method:: Symbol ) where {T}
101
+ if size (attr, 3 ) == 1 # nothing need to reduce
102
+ return attr
103
+ elseif method == :maxabs
104
+ return maximum (abs, attr; dims= 3 )
105
+ elseif method == :norm
106
+ return mapslices (norm, attr; dims= 3 ):: T
65
107
elseif method == :sum
66
- return dropdims ( sum (attr; dims = 3 ) ; dims= 3 )
108
+ return sum (attr; dims= 3 )
67
109
end
68
110
throw (
69
111
ArgumentError (
70
- " Color channel reducer :$method not supported, `reduce` should be :maxabs or :sum " ,
112
+ " Color channel reducer :$method not supported, `reduce` should be :maxabs, :sum or :norm " ,
71
113
),
72
114
)
73
115
end
0 commit comments