1
1
# NOTE: Heatmapping assumes Flux's WHCN convention (width, height, color channels, batch size).
2
2
3
3
const HEATMAPPING_PRESETS = Dict {Symbol,Tuple{ColorScheme,Symbol,Symbol}} (
4
- # Analyzer => (colorscheme, reduce, normalize )
4
+ # Analyzer => (colorscheme, reduce, rangescale )
5
5
:LRP => (ColorSchemes. bwr, :sum , :centered ), # attribution
6
6
:InputTimesGradient => (ColorSchemes. bwr, :sum , :centered ), # attribution
7
7
:Gradient => (ColorSchemes. grays, :norm , :extrema ), # gradient
@@ -29,7 +29,7 @@ Assumes Flux's WHCN convention (width, height, color channels, batch size).
29
29
- `:maxabs`: compute `maximum(abs, x)` over the color channels in
30
30
When calling `heatmap` with an `Explanation` or analyzer, the method default is selected.
31
31
When calling `heatmap` with an array, the default is `:sum`.
32
- - `normalize ::Symbol`: How the color channel reduced heatmap is normalized before the colorscheme is applied.
32
+ - `rangescale ::Symbol`: How the color channel reduced heatmap is normalized before the colorscheme is applied.
33
33
Can be either `:extrema` or `:centered`.
34
34
When calling `heatmap` with an `Explanation` or analyzer, the method default is selected.
35
35
When calling `heatmap` with an array, the default for use with the `bwr` colorscheme is `:centered`.
@@ -43,7 +43,7 @@ function heatmap(
43
43
attr:: AbstractArray{T,N} ;
44
44
cs:: ColorScheme = ColorSchemes. bwr,
45
45
reduce:: Symbol = :sum ,
46
- normalize :: Symbol = :centered ,
46
+ rangescale :: Symbol = :centered ,
47
47
permute:: Bool = true ,
48
48
unpack_singleton:: Bool = true ,
49
49
) where {T,N}
@@ -55,18 +55,18 @@ function heatmap(
55
55
),
56
56
)
57
57
if unpack_singleton && size (attr, 4 ) == 1
58
- return _heatmap (attr[:, :, :, 1 ], cs, reduce, normalize , permute)
58
+ return _heatmap (attr[:, :, :, 1 ], cs, reduce, rangescale , permute)
59
59
end
60
- return map (a -> _heatmap (a, cs, reduce, normalize , permute), eachslice (attr; dims= 4 ))
60
+ return map (a -> _heatmap (a, cs, reduce, rangescale , permute), eachslice (attr; dims= 4 ))
61
61
end
62
62
63
63
# Use HEATMAPPING_PRESETS for default kwargs when dispatching on Explanation
64
64
function heatmap (expl:: Explanation ; permute:: Bool = true , kwargs... )
65
- _cs, _reduce, _normalize = HEATMAPPING_PRESETS[expl. analyzer]
65
+ _cs, _reduce, _rangescale = HEATMAPPING_PRESETS[expl. analyzer]
66
66
return heatmap (
67
67
expl. attribution;
68
68
reduce= get (kwargs, :reduce , _reduce),
69
- normalize = get (kwargs, :normalize , _normalize ),
69
+ rangescale = get (kwargs, :rangescale , _rangescale ),
70
70
cs= get (kwargs, :cs , _cs),
71
71
permute= permute,
72
72
)
@@ -81,28 +81,12 @@ function _heatmap(
81
81
attr:: AbstractArray{T,3} ,
82
82
cs:: ColorScheme ,
83
83
reduce:: Symbol ,
84
- normalize :: Symbol ,
84
+ rangescale :: Symbol ,
85
85
permute:: Bool ,
86
86
) where {T<: Real }
87
- img = _normalize ( dropdims (_reduce (attr, reduce); dims= 3 ), normalize )
87
+ img = dropdims (_reduce (attr, reduce); dims= 3 )
88
88
permute && (img = permutedims (img))
89
- return ColorSchemes. get (cs, img)
90
- end
91
-
92
- # Normalize activations across pixels
93
- function _normalize (attr, method:: Symbol )
94
- if method == :centered
95
- min, max = (- 1 , 1 ) .* maximum (abs, attr)
96
- elseif method == :extrema
97
- min, max = extrema (attr)
98
- else
99
- throw (
100
- ArgumentError (
101
- " Color scheme normalizer :$method not supported, `normalize` should be :extrema or :centered" ,
102
- ),
103
- )
104
- end
105
- return (attr .- min) / (max - min)
89
+ return ColorSchemes. get (cs, img, rangescale)
106
90
end
107
91
108
92
# Reduce attributions across color channels into a single scalar – assumes WHCN convention
0 commit comments