Skip to content

Commit 3a2065b

Browse files
committed
Update ReadMe with VGG16 examples
1 parent 910438f commit 3a2065b

File tree

1 file changed

+49
-14
lines changed

1 file changed

+49
-14
lines changed

README.md

Lines changed: 49 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -17,36 +17,52 @@ This package supports Julia ≥1.6. To install it, open the Julia REPL and run
1717
julia> ]add ExplainableAI
1818
```
1919

20-
⚠️ This package is still in early development, expect breaking changes. ⚠️
21-
2220
## Example
23-
Let's use LRP to explain why an MNIST digit gets classified as a 9 using a small pre-trained LeNet5 model.
24-
If you want to follow along, the model can be found [here][model-bson-url].
21+
Let's use LRP to explain why an image of a castle gets classified as such using a pre-trained VGG16 model from [Metalhead.jl](https://github.com/FluxML/Metalhead.jl):
22+
![][castle]
2523
```julia
2624
using ExplainableAI
2725
using Flux
28-
using MLDatasets
29-
using BSON: @load
26+
using Metalhead
27+
using FileIO
3028

3129
# Load model
32-
@load "model.bson" model
33-
model = strip_softmax(model)
30+
model = VGG(16, pretrain=true).layers
31+
model = strip_softmax(flatten_chain(model))
3432

3533
# Load input
36-
x, _ = MNIST(Float32, :test)[10]
37-
input = reshape(x, 28, 28, 1, :) # reshape to WHCN format
34+
img = load("castle.jpg")
35+
input = preprocess_imagenet(img)
36+
input = reshape(input, 224, 224, 3, :) # reshape to WHCN format
3837

3938
# Run XAI method
4039
analyzer = LRP(model)
41-
expl = analyze(input, analyzer) # or: expl = analyzer(input)
40+
expl = analyze(input, analyzer) # or: expl = analyzer(input)
4241

4342
# Show heatmap
4443
heatmap(expl)
4544

4645
# Or analyze & show heatmap directly
4746
heatmap(input, analyzer)
4847
```
49-
![][heatmap]
48+
49+
We can also get an explanation for the activation of the output neuron corresponding to the "street sign" class by specifying the corresponding output neuron position `920`:
50+
```julia
51+
analyze(input, analyzer, 920) # for explanation
52+
heatmap(input, analyzer, 920) # for heatmap
53+
```
54+
Heatmaps for all implemented analyzers are shown in the following table. Red color indicate regions of positive relevance towards the selected class, whereas regions in blue are of negative relevance.
55+
56+
| **Analyzer** | **Heatmap for class "castle"** |**Heatmap for class "street sign"** |
57+
|:--------------------- |:------------------------------ |:---------------------------------- |
58+
| `LRP` composite | ![][castle-lrp-comp] | ![][streetsign-lrp-comp] |
59+
| `LRP` | ![][castle-lrp] | ![][streetsign-lrp] |
60+
| `InputTimesGradient` | ![][castle-ixg] | ![][streetsign-ixg] |
61+
| `Gradient` | ![][castle-grad] | ![][streetsign-grad] |
62+
| `SmoothGrad` | ![][castle-smoothgrad] | ![][streetsign-smoothgrad] |
63+
| `IntegratedGradients` | ![][castle-intgrad] | ![][streetsign-intgrad] |
64+
65+
The code used to generate these heatmaps can be found [here][asset-code].
5066

5167
## Methods
5268
Currently, the following analyzers are implemented:
@@ -70,6 +86,10 @@ Currently, the following analyzers are implemented:
7086
One of the design goals of ExplainableAI.jl is extensibility.
7187
Individual LRP rules [can be composed][docs-composites] and are easily extended by [custom rules][docs-custom-rules].
7288

89+
## Video demonstration
90+
Check out our [JuliaCon 2022 talk][juliacon-url] for a demonstration of the package.
91+
[![][juliacon-img]][juliacon-url]
92+
7393
## Roadmap
7494
In the future, we would like to include:
7595
- [PatternNet](https://arxiv.org/abs/1705.05598)
@@ -83,7 +103,21 @@ Contributions are welcome!
83103
> Adrian Hill acknowledges support by the Federal Ministry of Education and Research (BMBF) for the Berlin Institute for the Foundations of Learning and Data (BIFOLD) (01IS18037A).
84104
85105
[banner-img]: https://raw.githubusercontent.com/adrhill/ExplainableAI.jl/gh-pages/assets/banner.png
86-
[heatmap]: https://raw.githubusercontent.com/adrhill/ExplainableAI.jl/gh-pages/assets/mnist9.png
106+
107+
[asset-code]: https://github.com/adrhill/ExplainableAI.jl/blob/gh-pages/assets/heatmaps/readme_assets.jl
108+
[castle]: https://raw.githubusercontent.com/adrhill/ExplainableAI.jl/gh-pages/assets/heatmaps/castle.jpg
109+
[castle-lrp]: https://raw.githubusercontent.com/adrhill/ExplainableAI.jl/gh-pages/assets/heatmaps/castle_LRP.png
110+
[castle-lrp-comp]: https://raw.githubusercontent.com/adrhill/ExplainableAI.jl/gh-pages/assets/heatmaps/castle_LRP_composite.png
111+
[castle-ixg]: https://raw.githubusercontent.com/adrhill/ExplainableAI.jl/gh-pages/assets/heatmaps/castle_InputTimesGradient.png
112+
[castle-grad]: https://raw.githubusercontent.com/adrhill/ExplainableAI.jl/gh-pages/assets/heatmaps/castle_Gradient.png
113+
[castle-smoothgrad]: https://raw.githubusercontent.com/adrhill/ExplainableAI.jl/gh-pages/assets/heatmaps/castle_SmoothGrad.png
114+
[castle-intgrad]: https://raw.githubusercontent.com/adrhill/ExplainableAI.jl/gh-pages/assets/heatmaps/castle_IntegratedGradients.png
115+
[streetsign-lrp]: https://raw.githubusercontent.com/adrhill/ExplainableAI.jl/gh-pages/assets/heatmaps/streetsign_LRP.png
116+
[streetsign-lrp-comp]: https://raw.githubusercontent.com/adrhill/ExplainableAI.jl/gh-pages/assets/heatmaps/streetsign_LRP_composite.png
117+
[streetsign-ixg]: https://raw.githubusercontent.com/adrhill/ExplainableAI.jl/gh-pages/assets/heatmaps/streetsign_InputTimesGradient.png
118+
[streetsign-grad]: https://raw.githubusercontent.com/adrhill/ExplainableAI.jl/gh-pages/assets/heatmaps/streetsign_Gradient.png
119+
[streetsign-smoothgrad]: https://raw.githubusercontent.com/adrhill/ExplainableAI.jl/gh-pages/assets/heatmaps/streetsign_SmoothGrad.png
120+
[streetsign-intgrad]: https://raw.githubusercontent.com/adrhill/ExplainableAI.jl/gh-pages/assets/heatmaps/streetsign_IntegratedGradients.png
87121

88122
[docs-stab-img]: https://img.shields.io/badge/docs-stable-blue.svg
89123
[docs-stab-url]: https://adrhill.github.io/ExplainableAI.jl/stable
@@ -103,7 +137,8 @@ Contributions are welcome!
103137
[doi-img]: https://zenodo.org/badge/337430397.svg
104138
[doi-url]: https://zenodo.org/badge/latestdoi/337430397
105139

106-
[model-bson-url]: https://github.com/adrhill/ExplainableAI.jl/blob/master/docs/src/model.bson
140+
[juliacon-img]: http://img.youtube.com/vi/p5dg3vdmlvI/0.jpg
141+
[juliacon-url]: https://www.youtube.com/watch?v=p5dg3vdmlvI
107142

108143
[captum-repo]: https://github.com/pytorch/captum
109144
[zennit-repo]: https://github.com/chr5tphr/zennit

0 commit comments

Comments
 (0)