Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ jobs:
version:
- 'lts'
- '1'
# - 'pre'
- 'pre'
steps:
- uses: actions/checkout@v6
- uses: julia-actions/setup-julia@v2
Expand Down Expand Up @@ -74,3 +74,17 @@ jobs:
using Documenter: doctest
using ExplainableAI
doctest(ExplainableAI)'

runic:
name: Runic formatting
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v5
# - uses: julia-actions/setup-julia@v2
# with:
# version: '1'
# - uses: julia-actions/cache@v2
- uses: fredrikekre/runic-action@v1
with:
version: '1'

4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# ExplainableAI.jl

## Version `v0.10.4-DEV`
- ![Maintenance][badge-maintenance] Switch from JuliaFormatter to Runic, update JET ([#188])

## Version `v0.10.3`
- ![Maintenance][badge-maintenance] Update dependencies

Expand Down Expand Up @@ -234,6 +237,7 @@ Performance improvements:
[VisionHeatmaps]: https://julia-xai.github.io/XAIDocs/VisionHeatmaps/stable/
[TextHeatmaps]: https://julia-xai.github.io/XAIDocs/TextHeatmaps/stable/

[#188]: https://github.com/Julia-XAI/ExplainableAI.jl/pull/188
[#184]: https://github.com/Julia-XAI/ExplainableAI.jl/pull/184
[#183]: https://github.com/Julia-XAI/ExplainableAI.jl/pull/183
[#180]: https://github.com/Julia-XAI/ExplainableAI.jl/pull/180
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "ExplainableAI"
uuid = "4f1bc3e1-d60d-4ed0-9367-9bdff9846d3b"
authors = ["Adrian Hill <gh@adrianhill.de>"]
version = "0.10.3"
version = "0.10.4-DEV"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
8 changes: 1 addition & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ ___
| Documentation | [![][docs-stab-img]][docs-stab-url] [![][docs-dev-img]][docs-dev-url] [![][changelog-img]][changelog-url] |
| Build Status | [![][ci-img]][ci-url] [![][codecov-img]][codecov-url] |
| Testing | [![Aqua][aqua-img]][aqua-url] [![JET][jet-img]][jet-url] |
| Code Style | [![Code Style: Blue][blue-img]][blue-url] [![ColPrac][colprac-img]][colprac-url] |
| Code Style | [![Code Style: Runic](https://img.shields.io/badge/code_style-%E1%9A%B1%E1%9A%A2%E1%9A%BE%E1%9B%81%E1%9A%B2-black)](https://github.com/fredrikekre/Runic.jl) [![ColPrac: Contributor's Guide on Collaborative Practices for Community Packages](https://img.shields.io/badge/ColPrac-Contributor's%20Guide-blueviolet)](https://github.com/SciML/ColPrac) |
| Citation | [![][doi-img]][doi-url] |

Explainable AI in Julia.
Expand Down Expand Up @@ -182,12 +182,6 @@ Contributions are welcome!
[jet-img]: https://img.shields.io/badge/%F0%9F%9B%A9%EF%B8%8F_tested_with-JET.jl-233f9a
[jet-url]: https://github.com/aviatesk/JET.jl


[blue-img]: https://img.shields.io/badge/code%20style-blue-4495d1.svg
[blue-url]: https://github.com/invenia/BlueStyle
[colprac-img]: https://img.shields.io/badge/ColPrac-Contributor's%20Guide-blueviolet
[colprac-url]: https://github.com/SciML/ColPrac

[docs-composites]: https://julia-xai.github.io/ExplainableAI.jl/stable/generated/lrp/composites/
[docs-custom-rules]: https://julia-xai.github.io/ExplainableAI.jl/stable/generated/lrp/custom_rules/

Expand Down
14 changes: 7 additions & 7 deletions benchmark/bench_jogger.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@ input = rand(T, input_size)

model = Chain(
Chain(
Conv((3, 3), 3 => 8, relu; pad=1),
Conv((3, 3), 8 => 8, relu; pad=1),
Conv((3, 3), 3 => 8, relu; pad = 1),
Conv((3, 3), 8 => 8, relu; pad = 1),
MaxPool((2, 2)),
Conv((3, 3), 8 => 16, relu; pad=1),
Conv((3, 3), 16 => 16, relu; pad=1),
Conv((3, 3), 8 => 16, relu; pad = 1),
Conv((3, 3), 16 => 16, relu; pad = 1),
MaxPool((2, 2)),
),
Chain(
Expand All @@ -29,9 +29,9 @@ Flux.testmode!(model, true)

# Use one representative algorithm of each type
METHODS = Dict(
"Gradient" => Gradient,
"InputTimesGradient" => InputTimesGradient,
"SmoothGrad" => model -> SmoothGrad(model, 5),
"Gradient" => Gradient,
"InputTimesGradient" => InputTimesGradient,
"SmoothGrad" => model -> SmoothGrad(model, 5),
"IntegratedGradients" => model -> IntegratedGradients(model, 5),
)

Expand Down
27 changes: 14 additions & 13 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,22 +13,23 @@ function convert_literate(dir_in, dir_out)
if isdir(path)
convert_literate(path, joinpath(dir_out, p))
else # isfile
Literate.markdown(path, dir_out; documenter=true) # Markdown for Documenter.jl
Literate.markdown(path, dir_out; documenter = true) # Markdown for Documenter.jl
Literate.notebook(path, dir_out) # .ipynb notebook
Literate.script(path, dir_out) # .jl script
end
end
return nothing
end
convert_literate(LITERATE_DIR, OUT_DIR)

makedocs(;
modules=[XAIBase, ExplainableAI],
authors="Adrian Hill",
sitename="ExplainableAI.jl",
format=Documenter.HTML(;
prettyurls=get(ENV, "CI", "false") == "true",
size_threshold=300_000,
assets=String[],
modules = [XAIBase, ExplainableAI],
authors = "Adrian Hill",
sitename = "ExplainableAI.jl",
format = Documenter.HTML(;
prettyurls = get(ENV, "CI", "false") == "true",
size_threshold = 300_000,
assets = String[],
),
#! format: off
pages=[
Expand All @@ -40,13 +41,13 @@ makedocs(;
"API Reference" => "api.md",
],
#! format: on
linkcheck=true,
linkcheck_ignore=[
linkcheck = true,
linkcheck_ignore = [
r"https://link.springer.com/chapter/10.1007/978-3-030-28954-6_10",
r"https://www.nature.com/articles/s42256-023-00711-8",
],
warnonly=[:missing_docs],
checkdocs=:exports, # only check docstrings in API reference if they are exported
warnonly = [:missing_docs],
checkdocs = :exports, # only check docstrings in API reference if they are exported
)

deploydocs(; repo="github.com/Julia-XAI/ExplainableAI.jl")
deploydocs(; repo = "github.com/Julia-XAI/ExplainableAI.jl")
2 changes: 1 addition & 1 deletion docs/src/literate/augmentations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ heatmap(input, analyzer)
matrix_of_ones = ones(Float32, size(input))

analyzer = InterpolationAugmentation(Gradient(model), 50)
expl = analyzer(input; input_ref=matrix_of_ones)
expl = analyzer(input; input_ref = matrix_of_ones)
heatmap(expl)

# Once again, `InterpolationAugmentation` can be combined with any analyzer type from the Julia-XAI ecosystem,
Expand Down
2 changes: 1 addition & 1 deletion docs/src/literate/example.jl
Original file line number Diff line number Diff line change
Expand Up @@ -123,5 +123,5 @@ analyzer = InputTimesGradient(model)
heatmap(input, analyzer)

# Using [VisionHeatmaps.jl](https://julia-xai.github.io/XAIDocs/VisionHeatmaps/stable/),
# heatmaps can be heavily customized.
# heatmaps can be heavily customized.
# Check out the [heatmapping documentation](https://julia-xai.github.io/XAIDocs/VisionHeatmaps/stable/) for more information.
12 changes: 6 additions & 6 deletions src/gradcam.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,15 @@ GradCAM is compatible with a wide variety of CNN model-families.
# References
- $REF_SELVARAJU_GRADCAM
"""
struct GradCAM{F,A,B<:AbstractADType} <: AbstractXAIMethod
struct GradCAM{F, A, B <: AbstractADType} <: AbstractXAIMethod
feature_layers::F
adaptation_layers::A
backend::B

function GradCAM(
feature_layers::F, adaptation_layers::A, backend::B=DEFAULT_AD_BACKEND
) where {F,A,B<:AbstractADType}
new{F,A,B}(feature_layers, adaptation_layers, backend)
feature_layers::F, adaptation_layers::A, backend::B = DEFAULT_AD_BACKEND
) where {F, A, B <: AbstractADType}
return new{F, A, B}(feature_layers, adaptation_layers, backend)
end
end
function call_analyzer(input, analyzer::GradCAM, ns::AbstractOutputSelector; kwargs...)
Expand All @@ -34,7 +34,7 @@ function call_analyzer(input, analyzer::GradCAM, ns::AbstractOutputSelector; kwa
grad, output, output_indices = gradient_wrt_input(
analyzer.adaptation_layers, A, ns, analyzer.backend
)
αᶜ = sum(grad; dims=(1, 2)) / feature_map_size
Lᶜ = max.(sum(αᶜ .* A; dims=3), 0)
αᶜ = sum(grad; dims = (1, 2)) / feature_map_size
Lᶜ = max.(sum(αᶜ .* A; dims = 3), 0)
return Explanation(Lᶜ, input, output, output_indices, :GradCAM, :cam, nothing)
end
30 changes: 15 additions & 15 deletions src/gradient.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@ function forward_with_output_selection(model, input, selector::AbstractOutputSel
end

function gradient_wrt_input(
model, input, output_selector::AbstractOutputSelector, backend::AbstractADType
)
model, input, output_selector::AbstractOutputSelector, backend::AbstractADType
)
output = model(input)
return gradient_wrt_input(model, input, output, output_selector, backend)
end

function gradient_wrt_input(
model, input, output, output_selector::AbstractOutputSelector, backend::AbstractADType
)
model, input, output, output_selector::AbstractOutputSelector, backend::AbstractADType
)
output_selection = output_selector(output)
dy = zero(output)
dy[output_selection] .= 1
Expand All @@ -28,12 +28,12 @@ end

Analyze model by calculating the gradient of a neuron activation with respect to the input.
"""
struct Gradient{M,B<:AbstractADType} <: AbstractXAIMethod
struct Gradient{M, B <: AbstractADType} <: AbstractXAIMethod
model::M
backend::B

function Gradient(model::M, backend::B=DEFAULT_AD_BACKEND) where {M,B<:AbstractADType}
new{M,B}(model, backend)
function Gradient(model::M, backend::B = DEFAULT_AD_BACKEND) where {M, B <: AbstractADType}
return new{M, B}(model, backend)
end
end

Expand All @@ -52,20 +52,20 @@ end
Analyze model by calculating the gradient of a neuron activation with respect to the input.
This gradient is then multiplied element-wise with the input.
"""
struct InputTimesGradient{M,B<:AbstractADType} <: AbstractXAIMethod
struct InputTimesGradient{M, B <: AbstractADType} <: AbstractXAIMethod
model::M
backend::B

function InputTimesGradient(
model::M, backend::B=DEFAULT_AD_BACKEND
) where {M,B<:AbstractADType}
new{M,B}(model, backend)
model::M, backend::B = DEFAULT_AD_BACKEND
) where {M, B <: AbstractADType}
return new{M, B}(model, backend)
end
end

function call_analyzer(
input, analyzer::InputTimesGradient, ns::AbstractOutputSelector; kwargs...
)
input, analyzer::InputTimesGradient, ns::AbstractOutputSelector; kwargs...
)
grad, output, output_indices = gradient_wrt_input(
analyzer.model, input, ns, analyzer.backend
)
Expand All @@ -91,7 +91,7 @@ e.g. `std = 0.1 * (maximum(input) - minimum(input))`.
# References
- $REF_SMILKOV_SMOOTHGRAD
"""
SmoothGrad(model, n=50, args...) = NoiseAugmentation(Gradient(model), n, args...)
SmoothGrad(model, n = 50, args...) = NoiseAugmentation(Gradient(model), n, args...)

"""
IntegratedGradients(analyzer, [n=50])
Expand All @@ -102,4 +102,4 @@ Analyze model by using the Integrated Gradients method.
# References
- $REF_SUNDARARAJAN_AXIOMATIC
"""
IntegratedGradients(model, n=50) = InterpolationAugmentation(Gradient(model), n)
IntegratedGradients(model, n = 50) = InterpolationAugmentation(Gradient(model), n)
28 changes: 14 additions & 14 deletions src/input_augmentation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,24 +24,24 @@ e.g. `std = 0.1 * (maximum(input) - minimum(input))`.
Defaults to `GLOBAL_RNG`.
- `show_progress:Bool`: Show progress meter while sampling augmentations. Defaults to `true`.
"""
struct NoiseAugmentation{A<:AbstractXAIMethod,D<:Sampleable,R<:AbstractRNG} <:
AbstractXAIMethod
struct NoiseAugmentation{A <: AbstractXAIMethod, D <: Sampleable, R <: AbstractRNG} <:
AbstractXAIMethod
analyzer::A
n::Int
distribution::D
rng::R
show_progress::Bool

function NoiseAugmentation(
analyzer::A, n::Int, distribution::D, rng::R=GLOBAL_RNG, show_progress=true
) where {A<:AbstractXAIMethod,D<:Sampleable,R<:AbstractRNG}
analyzer::A, n::Int, distribution::D, rng::R = GLOBAL_RNG, show_progress = true
) where {A <: AbstractXAIMethod, D <: Sampleable, R <: AbstractRNG}
n < 1 && throw(ArgumentError("Number of samples `n` needs to be larger than zero."))
return new{A,D,R}(analyzer, n, distribution, rng, show_progress)
return new{A, D, R}(analyzer, n, distribution, rng, show_progress)
end
end
function NoiseAugmentation(
analyzer, n::Int, std::T=1.0f0, rng=GLOBAL_RNG, show_progress=true
) where {T<:Real}
analyzer, n::Int, std::T = 1.0f0, rng = GLOBAL_RNG, show_progress = true
) where {T <: Real}
distribution = Normal(zero(T), std^2)
return NoiseAugmentation(analyzer, n, distribution, rng, show_progress)
end
Expand All @@ -52,7 +52,7 @@ function call_analyzer(input, aug::NoiseAugmentation, ns::AbstractOutputSelector
output_indices = ns(output)
output_selector = AugmentationSelector(output_indices)

p = Progress(aug.n; desc="Sampling NoiseAugmentation...", enabled=aug.show_progress)
p = Progress(aug.n; desc = "Sampling NoiseAugmentation...", enabled = aug.show_progress)

# First augmentation
noisy_input = similar(input)
Expand All @@ -78,8 +78,8 @@ function call_analyzer(input, aug::NoiseAugmentation, ns::AbstractOutputSelector
end

function sample_noise!(
out::A, input::A, aug::NoiseAugmentation
) where {T,A<:AbstractArray{T}}
out::A, input::A, aug::NoiseAugmentation
) where {T, A <: AbstractArray{T}}
out = rand!(aug.rng, aug.distribution, out)
out .+= input
return out
Expand All @@ -93,11 +93,11 @@ between the input and a reference input (typically `zero(input)`).
The gradients w.r.t. this augmented input are then averaged and multiplied with the
difference between the input and the reference input.
"""
struct InterpolationAugmentation{A<:AbstractXAIMethod} <: AbstractXAIMethod
struct InterpolationAugmentation{A <: AbstractXAIMethod} <: AbstractXAIMethod
analyzer::A
n::Int

function InterpolationAugmentation(analyzer::A, n::Int) where {A<:AbstractXAIMethod}
function InterpolationAugmentation(analyzer::A, n::Int) where {A <: AbstractXAIMethod}
n < 2 && throw(
ArgumentError("Number of interpolation steps `n` needs to be larger than one."),
)
Expand All @@ -106,8 +106,8 @@ struct InterpolationAugmentation{A<:AbstractXAIMethod} <: AbstractXAIMethod
end

function call_analyzer(
input, aug::InterpolationAugmentation, ns::AbstractOutputSelector; input_ref=zero(input)
)
input, aug::InterpolationAugmentation, ns::AbstractOutputSelector; input_ref = zero(input)
)
size(input) != size(input_ref) &&
throw(ArgumentError("Input reference size doesn't match input size."))

Expand Down
5 changes: 4 additions & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
PkgJogger = "10150987-6cc1-4b76-abee-b1c1cbd91c01"
Expand All @@ -16,3 +15,7 @@ StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
XAIBase = "9b48221d-a747-4c1b-9860-46a1d8ba24a7"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
Aqua = "0.8"
JET = "0.9, 0.11"
Loading
Loading