Skip to content

Commit 295266c

Browse files
authored
Move core interface into XAIBase.jl package (#154)
* Add XAIBase dependency * Add heatmapping preset field to Explanations * Update `heatmap` kwarg `cs` to `colorscheme` * Remove tests that were moved to XAIBase * Update compat entries of dependencies and test dependencies * Update documentation
1 parent 9a98f97 commit 295266c

18 files changed

+52
-407
lines changed

Project.toml

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ExplainableAI"
22
uuid = "4f1bc3e1-d60d-4ed0-9367-9bdff9846d3b"
33
authors = ["Adrian Hill <[email protected]>"]
4-
version = "0.6.3"
4+
version = "1.0.0-DEV"
55

66
[deps]
77
ColorSchemes = "35d6a980-a343-548e-a6ea-1d62b119f2f4"
@@ -12,8 +12,10 @@ ImageTransformations = "02fcd773-0e25-5acc-982a-7f6622650795"
1212
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
1313
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
1414
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
15+
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
1516
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1617
Tullio = "bc48ee85-29a4-5162-ae0b-a64e1601d4bc"
18+
XAIBase = "9b48221d-a747-4c1b-9860-46a1d8ba24a7"
1719
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
1820

1921
[weakdeps]
@@ -29,6 +31,11 @@ Flux = "0.13, 0.14"
2931
ImageCore = "0.9, 0.10"
3032
ImageTransformations = "0.9, 0.10"
3133
MacroTools = "0.5"
34+
Markdown = "1"
35+
Random = "1"
36+
Reexport = "1"
37+
Statistics = "1"
3238
Tullio = "0.3"
39+
XAIBase = "1.2"
3340
Zygote = "0.6"
3441
julia = "1.6"

docs/make.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
using ExplainableAI
21
using Documenter
32
using Literate
3+
using ExplainableAI
44

55
LITERATE_DIR = joinpath(@__DIR__, "src/literate")
66
OUT_DIR = joinpath(@__DIR__, "src/generated")
@@ -22,7 +22,7 @@ end
2222
convert_literate(LITERATE_DIR, OUT_DIR)
2323

2424
makedocs(;
25-
modules=[ExplainableAI],
25+
modules=[XAIBase, ExplainableAI],
2626
authors="Adrian Hill",
2727
sitename="ExplainableAI.jl",
2828
format=Documenter.HTML(; prettyurls=get(ENV, "CI", "false") == "true", assets=String[]),

docs/src/literate/example.jl

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -48,20 +48,22 @@ analyzer = LRP(model)
4848
expl = analyze(input, analyzer);
4949

5050
# The return value `expl` is of type [`Explanation`](@ref) and bundles the following data:
51-
# * `expl.val`: the numerical output of the analyzer, e.g. an attribution or gradient
52-
# * `expl.output`: the model output for the given analyzer input
53-
# * `expl.neuron_selection`: the neuron index used for the explanation
54-
# * `expl.analyzer`: a symbol corresponding the used analyzer, e.g. `:LRP`
55-
# * `expl.extras`: an optional named tuple that can be used by analyzers
56-
# to return additional information.
51+
# * `expl.val`: numerical output of the analyzer, e.g. an attribution or gradient
52+
# * `expl.output`: model output for the given analyzer input
53+
# * `expl.output_selection`: index of the output used for the explanation
54+
# * `expl.analyzer`: symbol corresponding the used analyzer, e.g. `:Gradient` or `:LRP`
55+
# * `expl.heatmap`: symbol indicating a preset heatmapping style,
56+
# e.g. `:attibution`, `:sensitivity` or `:cam`
57+
# * `expl.extras`: optional named tuple that can be used by analyzers
58+
# to return additional information.
5759
#
5860
# We used an LRP analyzer, so `expl.analyzer` is `:LRP`.
5961
expl.analyzer
6062

6163
# By default, the explanation is computed for the maximally activated output neuron.
6264
# Since our digit is a 9 and Julia's indexing is 1-based,
6365
# the output neuron at index `10` of our trained model is maximally activated.
64-
expl.neuron_selection
66+
expl.output_selection
6567

6668
# Finally, we obtain the result of the analyzer in form of an array.
6769
expl.val

docs/src/literate/heatmapping.jl

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,9 @@ heatmap(input, analyzer)
3939
using ColorSchemes
4040

4141
expl = analyze(input, analyzer)
42-
heatmap(expl; cs=ColorSchemes.jet)
42+
heatmap(expl; colorscheme=:jet)
4343
#-
44-
heatmap(expl; cs=ColorSchemes.inferno)
44+
heatmap(expl; colorscheme=:inferno)
4545

4646
# Refer to the [ColorSchemes.jl catalogue](https://juliagraphics.github.io/ColorSchemes.jl/stable/basics/)
4747
# for a gallery of available color schemes.
@@ -84,9 +84,9 @@ heatmap(expl; rangescale=:centered)
8484
heatmap(expl; rangescale=:extrema)
8585
# However, for the `inferno` color scheme, which is not centered around zero,
8686
# `:extrema` leads to a heatmap with higher contrast.
87-
heatmap(expl; rangescale=:centered, cs=ColorSchemes.inferno)
87+
heatmap(expl; rangescale=:centered, colorscheme=:inferno)
8888
#-
89-
heatmap(expl; rangescale=:extrema, cs=ColorSchemes.inferno)
89+
heatmap(expl; rangescale=:extrema, colorscheme=:inferno)
9090

9191
# For the full list of `heatmap` keyword arguments, refer to the [`heatmap`](@ref) documentation.
9292

@@ -110,11 +110,13 @@ mosaic(heatmaps; nrow=10)
110110
#
111111
# If this bevahior is not desired,
112112
# `heatmap` can be called with the keyword-argument `process_batch=true`:
113-
heatmaps = heatmap(batch, analyzer; process_batch=true)
113+
expl = analyze(batch, analyzer)
114+
heatmaps = heatmap(expl; process_batch=true)
114115
mosaic(heatmaps; nrow=10)
115116

116117
# This can be useful when comparing heatmaps for fixed output neurons:
117-
heatmaps = heatmap(batch, analyzer, 7; process_batch=true) # heatmaps for digit "6"
118+
expl = analyze(batch, analyzer, 7) # explain digit "6"
119+
heatmaps = heatmap(expl; process_batch=true)
118120
mosaic(heatmaps; nrow=10)
119121

120122
#md # !!! note "Output type consistency"

src/ExplainableAI.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
module ExplainableAI
22

3+
using Reexport
4+
@reexport using XAIBase
5+
36
using Base.Iterators
47
using MacroTools: @forward
58
using Distributions: Distribution, Sampleable, Normal
@@ -15,8 +18,6 @@ using ColorSchemes
1518

1619
include("compat.jl")
1720
include("bibliography.jl")
18-
include("neuron_selection.jl")
19-
include("analyze_api.jl")
2021
include("flux_types.jl")
2122
include("flux_layer_utils.jl")
2223
include("flux_chain_utils.jl")
@@ -31,7 +32,6 @@ include("lrp/lrp.jl")
3132
include("lrp/show.jl")
3233
include("lrp/composite_presets.jl") # uses lrp/show.jl
3334
include("lrp/crp.jl")
34-
include("heatmap.jl")
3535
include("preprocessing.jl")
3636
export analyze
3737

src/analyze_api.jl

Lines changed: 0 additions & 77 deletions
This file was deleted.

src/gradient.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ end
2222

2323
function (analyzer::Gradient)(input, ns::AbstractNeuronSelector)
2424
grad, output, output_indices = gradient_wrt_input(analyzer.model, input, ns)
25-
return Explanation(grad, output, output_indices, :Gradient, nothing)
25+
return Explanation(grad, output, output_indices, :Gradient, :sensitivity, nothing)
2626
end
2727

2828
"""
@@ -42,7 +42,7 @@ end
4242
function (analyzer::InputTimesGradient)(input, ns::AbstractNeuronSelector)
4343
grad, output, output_indices = gradient_wrt_input(analyzer.model, input, ns)
4444
attr = input .* grad
45-
return Explanation(attr, output, output_indices, :InputTimesGradient, nothing)
45+
return Explanation(attr, output, output_indices, :InputTimesGradient, :attribution, nothing)
4646
end
4747

4848
"""

src/heatmap.jl

Lines changed: 0 additions & 114 deletions
This file was deleted.

src/input_augmentation.jl

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,13 @@
1+
"""
2+
AugmentationSelector(index)
3+
4+
Neuron selector that passes through an augmented neuron selection.
5+
"""
6+
struct AugmentationSelector{I} <: AbstractNeuronSelector
7+
indices::I
8+
end
9+
(s::AugmentationSelector)(out) = s.indices
10+
111
"""
212
augment_batch_dim(input, n)
313
@@ -109,6 +119,7 @@ function (aug::NoiseAugmentation)(input, ns::AbstractNeuronSelector)
109119
output,
110120
output_indices,
111121
augmented_expl.analyzer,
122+
augmented_expl.heatmap,
112123
nothing,
113124
)
114125
end
@@ -148,7 +159,7 @@ function (aug::InterpolationAugmentation)(
148159
# Average gradients and compute explanation
149160
expl = (input - input_ref) .* reduce_augmentation(augmented_expl.val, aug.n)
150161

151-
return Explanation(expl, output, output_indices, augmented_expl.analyzer, nothing)
162+
return Explanation(expl, output, output_indices, augmented_expl.analyzer, augmented_expl.heatmap, nothing)
152163
end
153164

154165
"""

src/lrp/crp.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ function (crp::CRP)(input::AbstractArray{T,N}, ns::AbstractNeuronSelector) where
9494
end
9595
end
9696
end
97-
return Explanation(R_return, last(as), ns(last(as)), :CRP, nothing)
97+
return Explanation(R_return, last(as), ns(last(as)), :CRP, :attribution, nothing)
9898
end
9999

100100
#===================#

0 commit comments

Comments
 (0)