Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 22 additions & 11 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,20 +28,31 @@ Let's explain why an image of a castle is classified as such by a vision model:

```julia
using ExplainableAI

# Load model and input
model = ... # load classifier model
input = ... # input in batch-dimension-last format
using VisionHeatmaps # visualization of explanations as heatmaps
using Zygote # load autodiff backend for gradient-based methods
using Flux, Metalhead # pre-trained vision models in Flux
using DataAugmentation # input preprocessing
using HTTP, FileIO, ImageIO # load image from URL
using ImageInTerminal # show heatmap in terminal

# Load & prepare model
model = VGG(16, pretrain=true)

# Load input
url = HTTP.URI("https://raw.githubusercontent.com/Julia-XAI/ExplainableAI.jl/gh-pages/assets/heatmaps/castle.jpg")
img = load(url)

# Preprocess input
mean = (0.485f0, 0.456f0, 0.406f0)
std = (0.229f0, 0.224f0, 0.225f0)
tfm = CenterResizeCrop((224, 224)) |> ImageToTensor() |> Normalize(mean, std)
input = apply(tfm, Image(img)) # apply DataAugmentation transform
input = reshape(input.data, 224, 224, 3, :) # unpack data and add batch dimension

# Run XAI method
analyzer = SmoothGrad(model)
expl = analyze(input, analyzer) # or: analyzer(input)

# Show heatmap
heatmap(expl)

# Or analyze & show heatmap directly
heatmap(input, analyzer)
expl = analyze(input, analyzer) # or: expl = analyzer(input)
heatmap(expl) # show heatmap using VisionHeatmaps.jl
```

By default, explanations are computed for the class with the highest activation.
Expand Down
Loading