
Commit c778bba

Add Explanation type and dispatch heatmap on it (#36)
* adds wrapper `Explanation` around analysis that contains meta-data such as the used analyzer
* adds `:norm` reducer
* dispatch `heatmap` on analyzer symbol in Explanation
* add option to directly call `heatmap` with analyzer
1 parent b4fb888 commit c778bba

21 files changed, +214 -41 lines changed
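
In short: `analyze` now returns an `Explanation` instead of an `(expl, out)` tuple, and `heatmap` picks its colorscheme, reducer, and normalizer from the analyzer recorded in that wrapper. A minimal usage sketch of the changed API, assuming a Flux `model` and a WHCN-shaped `input` as in the package's example script:

```julia
analyzer = LRPZero(model)

# before this commit: expl, out = analyze(input, analyzer)
expl = analyze(input, analyzer)   # now returns an Explanation

heatmap(expl)             # presets chosen based on expl.analyzer
heatmap(input, analyzer)  # analyze and heatmap in one call
```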

docs/literate/example.jl

Lines changed: 12 additions & 7 deletions
@@ -41,15 +41,22 @@ input = permutedims(input, (2,1,3))[:,:,:,:] * 255; # flip X/Y axes, add batch d
 # We can now select an analyzer of our choice
 # and call [`analyze`](@ref) to get an explanation `expl`:
 analyzer = LRPZero(model)
-expl, out = analyze(input, analyzer);
+expl = analyze(input, analyzer);

 # Finally, we can visualize the explanation through heatmapping:
 heatmap(expl)

+# Or do both in one combined step:
+heatmap(input, analyzer)
+
 #md # !!! tip "Neuron selection"
 #md #     To get an explanation with respect to a specific output neuron (e.g. class 42) call
 #md #     ```julia
-#md #     expl, out = analyze(img, analyzer, 42)
+#md #     expl = analyze(img, analyzer, 42)
+#md #     ```
+#md #     or using `heatmap`
+#md #     ```julia
+#md #     heatmap(img, analyzer, 42)
 #md #     ```
 #
 # Currently, the following analyzers are implemented:
@@ -75,13 +82,11 @@ model = flatten_model(model)
 #
 # Now we set a rule for each layer
 rules = [
-    ZBoxRule(),
-    repeat([GammaRule()], 15)...,
-    repeat([ZeroRule()], length(model) - 16)...
+    ZBoxRule(), repeat([GammaRule()], 15)..., repeat([ZeroRule()], length(model) - 16)...
 ]
 # and define a custom LRP analyzer:
 analyzer = LRP(model, rules)
-expl, out = analyze(input, analyzer)
+expl = analyze(input, analyzer)
 heatmap(expl)

 # ## Custom rules
@@ -98,7 +103,7 @@ end

 # We can directly use this rule to make an analyzer!
 analyzer = LRP(model, MyCustomLRPRule())
-expl, out = analyze(input, analyzer)
+expl = analyze(input, analyzer)
 heatmap(expl)

 #md # !!! tip "Pull requests welcome"
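
The updated tip also applies to explicit neuron selection; a short sketch restating it, assuming `img` is a WHCN array and `analyzer` is any of the implemented analyzers:

```julia
expl = analyze(img, analyzer, 42)  # explanation w.r.t. output neuron 42
heatmap(expl)

heatmap(img, analyzer, 42)         # same result in a single call
```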

src/analyze_api.jl

Lines changed: 11 additions & 1 deletion
@@ -13,7 +13,7 @@ Otherwise, the output neuron with the highest activation is automatically chosen
 function analyze(
     input::AbstractArray{<:Real},
     method::AbstractXAIMethod,
-    neuron_selection::Integer,
+    neuron_selection::Integer;
     kwargs...,
 )
     return method(input, IndexNS(neuron_selection); kwargs...)
@@ -22,3 +22,13 @@ end
 function analyze(input::AbstractArray{<:Real}, method::AbstractXAIMethod; kwargs...)
     return method(input, MaxActivationNS(); kwargs...)
 end
+
+# Explanations and outputs are returned in a wrapper.
+# Metadata such as the analyzer allows dispatching on functions like `heatmap`.
+struct Explanation{A,O,L}
+    attribution::A
+    output::O
+    neuron_selection::Int
+    analyzer::Symbol
+    layerwise_relevances::L
+end
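
For orientation, typical access patterns on the new wrapper; a sketch assuming `expl` was produced by `analyze` as in the example above:

```julia
expl = analyze(input, analyzer)

expl.attribution       # raw attribution array (what heatmap visualizes)
expl.output            # model output for `input`
expl.neuron_selection  # index of the explained output neuron
expl.analyzer          # Symbol such as :LRP, :Gradient or :InputTimesGradient
heatmap(expl)          # dispatches its defaults on expl.analyzer
```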

src/gradient.jl

Lines changed: 4 additions & 4 deletions
@@ -10,8 +10,8 @@ end
 function (analyzer::Gradient)(input, ns::AbstractNeuronSelector)
     output = analyzer.model(input)
     output_neuron = ns(output)
-    expl = gradient((in) -> analyzer.model(in)[output_neuron], input)[1]
-    return expl, output
+    attr = gradient((in) -> analyzer.model(in)[output_neuron], input)[1]
+    return Explanation(attr, output, output_neuron, :Gradient, Nothing)
 end

 """
@@ -29,6 +29,6 @@ end
 function (analyzer::InputTimesGradient)(input, ns::AbstractNeuronSelector)
     output = analyzer.model(input)
     output_neuron = ns(output)
-    expl = input .* gradient((in) -> analyzer.model(in)[output_neuron], input)[1]
-    return expl, output
+    attr = input .* gradient((in) -> analyzer.model(in)[output_neuron], input)[1]
+    return Explanation(attr, output, output_neuron, :InputTimesGradient, Nothing)
 end
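
Both gradient analyzers now tag their result with their own symbol, so heatmapping can pick the matching preset automatically. A sketch, assuming `model` and `input` from the example and assuming the `Gradient`/`InputTimesGradient` constructors take the model as their only argument (as suggested by the `analyzer.model` field):

```julia
expl = analyze(input, Gradient(model))
expl.analyzer == :Gradient   # true
heatmap(expl)                # grays / :norm / :extrema preset

expl = analyze(input, InputTimesGradient(model))
heatmap(expl)                # bwr / :sum / :centered preset, same as LRP
```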

src/heatmap.jl

Lines changed: 58 additions & 16 deletions
@@ -1,28 +1,50 @@
 # NOTE: Heatmapping assumes Flux's WHCN convention (width, height, color channels, batch size).

+const HEATMAPPING_PRESETS = Dict{Symbol,Tuple{ColorScheme,Symbol,Symbol}}(
+    # Analyzer => (colorscheme, reduce, normalize)
+    :LRP => (ColorSchemes.bwr, :sum, :centered),
+    :InputTimesGradient => (ColorSchemes.bwr, :sum, :centered), # same as LRP
+    :Gradient => (ColorSchemes.grays, :norm, :extrema),
+)
+
 """
-    heatmap(expl; kwargs...)
+    heatmap(expl::Explanation; kwargs...)
+    heatmap(attr::AbstractArray; kwargs...)
+    heatmap(input, analyzer::AbstractXAIMethod)
+    heatmap(input, analyzer::AbstractXAIMethod, neuron_selection::Int)

 Visualize explanation.
 Assumes the Flux's WHCN convention (width, height, color channels, batch size).

 ## Keyword arguments
--`cs::ColorScheme`: ColorScheme that is applied. Defaults to `ColorSchemes.bwr`.
+-`cs::ColorScheme`: ColorScheme that is applied.
+    When calling `heatmap` with an `Explanation` or analyzer, the method default is selected.
+    When calling `heatmap` with an array, the default is `ColorSchemes.bwr`.
 -`reduce::Symbol`: How the color channels are reduced to a single number to apply a colorscheme.
-    Can be either `:sum` or `:maxabs`. `:sum` sums up all color channels for each pixel.
-    `:maxabs` selects the `maximum(abs, x)` over the color channel in each pixel.
-    Default is `:sum`.
+    The following methods can be selected, which are then applied over the color channels
+    for each "pixel" in the attribution:
+    - `:sum`: sum up color channels
+    - `:norm`: compute 2-norm over the color channels
+    - `:maxabs`: compute `maximum(abs, x)` over the color channels in
+    When calling `heatmap` with an `Explanation` or analyzer, the method default is selected.
+    When calling `heatmap` with an array, the default is `:sum`.
 -`normalize::Symbol`: How the color channel reduced heatmap is normalized before the colorscheme is applied.
-    Can be either `:extrema` or `:centered`. Default for use with colorscheme `bwr` is `:centered`.
+    Can be either `:extrema` or `:centered`.
+    When calling `heatmap` with an `Explanation` or analyzer, the method default is selected.
+    When calling `heatmap` with an array, the default for use with the `bwr` colorscheme is `:centered`.
+-`permute::Bool`: Whether to flip W&H input channels. Default is `true`.
+
+**Note:** these keyword arguments can't be used when calling `heatmap` with an analyzer.
 """
+
 function heatmap(
-    expl::AbstractArray;
+    attr::AbstractArray;
     cs::ColorScheme=ColorSchemes.bwr,
     reduce::Symbol=:sum,
     normalize::Symbol=:centered,
     permute::Bool=true,
 )
-    _size = size(expl)
+    _size = size(attr)
     length(_size) != 4 && throw(
         DomainError(
             _size,
@@ -36,11 +58,27 @@ function heatmap(
            """heatmap is only applicable to a single attribution, got a batch dimension of $(_size[end]).""",
         ),
     )
-    # drop batch dim -> reduce color channels -> normalize image -> apply color scheme
-    img = _normalize(_reduce(dropdims(expl; dims=4), reduce), normalize)
+
+    img = _normalize(dropdims(_reduce(dropdims(attr; dims=4), reduce); dims=3), normalize)
     permute && (img = permutedims(img))
     return ColorSchemes.get(cs, img)
 end
+# Use HEATMAPPING_PRESETS for default kwargs when dispatching on Explanation
+function heatmap(expl::Explanation; permute::Bool=true, kwargs...)
+    _cs, _reduce, _normalize = HEATMAPPING_PRESETS[expl.analyzer]
+    return heatmap(
+        expl.attribution;
+        reduce=get(kwargs, :reduce, _reduce),
+        normalize=get(kwargs, :normalize, _normalize),
+        cs=get(kwargs, :cs, _cs),
+        permute=permute,
+    )
+end
+# Analyze & heatmap in one go
+function heatmap(input, analyzer::AbstractXAIMethod, args...; kwargs...)
+    return heatmap(analyze(input, analyzer, args...; kwargs...))
+end
+

 # Normalize activations across pixels
 function _normalize(attr, method::Symbol)
@@ -58,16 +96,20 @@ function _normalize(attr, method::Symbol)
     return (attr .- min) / (max - min)
 end

-# Reduces activation in a pixel with multiple color channels into a single activation
-function _reduce(attr, method::Symbol)
-    if method == :maxabs
-        return dropdims(maximum(abs, attr; dims=3); dims=3)
+# Reduce attributions across color channels into a single scalar – assumes WHCN convention
+function _reduce(attr::T, method::Symbol) where {T}
+    if size(attr, 3) == 1 # nothing need to reduce
+        return attr
+    elseif method == :maxabs
+        return maximum(abs, attr; dims=3)
+    elseif method == :norm
+        return mapslices(norm, attr; dims=3)::T
     elseif method == :sum
-        return dropdims(sum(attr; dims=3); dims=3)
+        return sum(attr; dims=3)
     end
     throw(
         ArgumentError(
-            "Color channel reducer :$method not supported, `reduce` should be :maxabs or :sum",
+            "Color channel reducer :$method not supported, `reduce` should be :maxabs, :sum or :norm",
         ),
     )
 end
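
Preset defaults can still be overridden per call, since explicit keyword arguments take precedence over the `HEATMAPPING_PRESETS` entry; a sketch of the override path and of the new `:norm` reducer, assuming `expl` is an `Explanation`:

```julia
heatmap(expl)                # preset defaults for expl.analyzer

heatmap(expl; reduce=:norm)  # 2-norm over the color channels
heatmap(expl; cs=ColorSchemes.grays, normalize=:extrema)
```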

src/lrp.jl

Lines changed: 7 additions & 5 deletions
@@ -66,9 +66,11 @@ function (analyzer::LRP)(input, ns::AbstractNeuronSelector; layerwise_relevances
         rels[i] .= lrp(rule, layers[i], acts[i], rels[i + 1])
     end

-    if layerwise_relevances
-        return rels, acts
-    end
-
-    return rels[1], acts[end] # expl, output
+    return Explanation(
+        first(rels),
+        last(acts),
+        output_neuron,
+        :LRP,
+        ifelse(layerwise_relevances, rels, Nothing),
+    )
 end
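
Layer-wise relevances now travel inside the wrapper instead of being returned as a separate tuple. A sketch, assuming the `model`, `rules`, and `input` from the example script:

```julia
analyzer = LRP(model, rules)
expl = analyze(input, analyzer; layerwise_relevances=true)

expl.attribution            # relevance at the input layer, i.e. first(rels)
expl.layerwise_relevances   # all per-layer relevances when requested
                            # (otherwise the placeholder Nothing)
```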
Three further new files add 1, 1, and 15 lines of block-character preview content that does not render meaningfully in this view.
