Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
Distributions = "0.25"
Random = "<0.0.1, 1"
Reexport = "1"
Statistics = "1"
XAIBase = "3"
Statistics = "<0.0.1, 1"
XAIBase = "4"
Zygote = "0.6"
julia = "1.6"
1 change: 1 addition & 0 deletions src/ExplainableAI.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ module ExplainableAI

using Reexport
@reexport using XAIBase
import XAIBase: call_analyzer

using Base.Iterators
using Distributions: Distribution, Sampleable, Normal
Expand Down
4 changes: 2 additions & 2 deletions src/gradcam.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@ struct GradCAM{F,A} <: AbstractXAIMethod
feature_layers::F
adaptation_layers::A
end
function (analyzer::GradCAM)(input, ns::AbstractOutputSelector)
function call_analyzer(input, analyzer::GradCAM, ns::AbstractOutputSelector; kwargs...)
A = analyzer.feature_layers(input) # feature map
feature_map_size = size(A, 1) * size(A, 2)

# Determine neuron importance αₖᶜ = 1/Z * ∑ᵢ ∑ⱼ ∂yᶜ / ∂Aᵢⱼᵏ
grad, output, output_indices = gradient_wrt_input(analyzer.adaptation_layers, A, ns)
αᶜ = sum(grad; dims=(1, 2)) / feature_map_size
Lᶜ = max.(sum(αᶜ .* A; dims=3), 0)
return Explanation(Lᶜ, output, output_indices, :GradCAM, :cam, nothing)
return Explanation(Lᶜ, input, output, output_indices, :GradCAM, :cam, nothing)
end
12 changes: 8 additions & 4 deletions src/gradient.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,11 @@ struct Gradient{M} <: AbstractXAIMethod
Gradient(model) = new{typeof(model)}(model)
end

function (analyzer::Gradient)(input, ns::AbstractOutputSelector)
function call_analyzer(input, analyzer::Gradient, ns::AbstractOutputSelector; kwargs...)
grad, output, output_indices = gradient_wrt_input(analyzer.model, input, ns)
return Explanation(grad, output, output_indices, :Gradient, :sensitivity, nothing)
return Explanation(
grad, input, output, output_indices, :Gradient, :sensitivity, nothing
)
end

"""
Expand All @@ -35,11 +37,13 @@ struct InputTimesGradient{M} <: AbstractXAIMethod
InputTimesGradient(model) = new{typeof(model)}(model)
end

function (analyzer::InputTimesGradient)(input, ns::AbstractOutputSelector)
function call_analyzer(
input, analyzer::InputTimesGradient, ns::AbstractOutputSelector; kwargs...
)
grad, output, output_indices = gradient_wrt_input(analyzer.model, input, ns)
attr = input .* grad
return Explanation(
attr, output, output_indices, :InputTimesGradient, :attribution, nothing
attr, input, output, output_indices, :InputTimesGradient, :attribution, nothing
)
end

Expand Down
8 changes: 5 additions & 3 deletions src/input_augmentation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ function NoiseAugmentation(analyzer, n, σ::Real=0.1f0, args...)
return NoiseAugmentation(analyzer, n, Normal(0.0f0, Float32(σ)^2), args...)
end

function (aug::NoiseAugmentation)(input, ns::AbstractOutputSelector)
function call_analyzer(input, aug::NoiseAugmentation, ns::AbstractOutputSelector; kwargs...)
# Regular forward pass of model
output = aug.analyzer.model(input)
output_indices = ns(output)
Expand All @@ -116,6 +116,7 @@ function (aug::NoiseAugmentation)(input, ns::AbstractOutputSelector)
# Average explanation
return Explanation(
reduce_augmentation(augmented_expl.val, aug.n),
input,
output,
output_indices,
augmented_expl.analyzer,
Expand All @@ -141,8 +142,8 @@ struct InterpolationAugmentation{A<:AbstractXAIMethod} <: AbstractXAIMethod
n::Int
end

function (aug::InterpolationAugmentation)(
input, ns::AbstractOutputSelector; input_ref=zero(input)
function call_analyzer(
input, aug::InterpolationAugmentation, ns::AbstractOutputSelector; input_ref=zero(input)
)
size(input) != size(input_ref) &&
throw(ArgumentError("Input reference size doesn't match input size."))
Expand All @@ -161,6 +162,7 @@ function (aug::InterpolationAugmentation)(

return Explanation(
expl,
input,
output,
output_indices,
augmented_expl.analyzer,
Expand Down
8 changes: 4 additions & 4 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,18 @@ using Aqua
using JET

@testset "ExplainableAI.jl" begin
@info "Testing formalities..."
if VERSION >= v"1.10"
@info "Testing formalities..."
@testset "Code formatting" begin
@info "- Testing code formatting with JuliaFormatter..."
@info "- running JuliaFormatter code formatting tests..."
@test JuliaFormatter.format(ExplainableAI; verbose=false, overwrite=false)
end
@testset "Aqua.jl" begin
@info "- Running Aqua.jl tests. These might print warnings from dependencies..."
@info "- running Aqua.jl tests. These might print warnings from dependencies..."
Aqua.test_all(ExplainableAI; ambiguities=false)
end
@testset "JET tests" begin
@info "- Testing type stability with JET..."
@info "- running JET.jl type stability tests..."
JET.test_package(ExplainableAI; target_defined_modules=true)
end
end
Expand Down
22 changes: 8 additions & 14 deletions test/test_batches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,12 @@ batchsize = 15

model = Chain(Dense(ins, 15, relu; init=pseudorand), Dense(15, outs, relu; init=pseudorand))

# Input 1 w/o batch dimension
input1_no_bd = rand(MersenneTwister(1), Float32, ins)
# Input 1 with batch dimension
input1_bd = reshape(input1_no_bd, ins, 1)
input1 = rand(MersenneTwister(1), Float32, ins, 1)
# Input 2 with batch dimension
input2_bd = rand(MersenneTwister(2), Float32, ins, 1)
input2 = rand(MersenneTwister(2), Float32, ins, 1)
# Batch containing inputs 1 & 2
input_batch = cat(input1_bd, input2_bd; dims=2)
input_batch = cat(input1, input2; dims=2)

ANALYZERS = Dict(
"Gradient" => Gradient,
Expand All @@ -33,25 +31,21 @@ ANALYZERS = Dict(

for (name, method) in ANALYZERS
@testset "$name" begin
# Using `add_batch_dim=true` should result in same explanation
# as input reshaped to have a batch dimension
analyzer = method(model)
expl1_no_bd = analyzer(input1_no_bd; add_batch_dim=true)
analyzer = method(model)
expl1_bd = analyzer(input1_bd)
@test expl1_bd.val ≈ expl1_no_bd.val
expl1 = analyzer(input1)
@test expl1.val ≈ expl1.val

# Analyzing a batch should have the same result
# as analyzing inputs in batch individually
analyzer = method(model)
expl2_bd = analyzer(input2_bd)
expl2 = analyzer(input2)
analyzer = method(model)
expl_batch = analyzer(input_batch)
@test expl1_bd.val ≈ expl_batch.val[:, 1]
@test expl1.val ≈ expl_batch.val[:, 1]
if !(analyzer isa NoiseAugmentation)
# NoiseAugmentation methods generate random numbers for the entire batch.
# therefore explanations don't match except for the first input in the batch.
@test expl2_bd.val ≈ expl_batch.val[:, 2]
@test expl2.val ≈ expl_batch.val[:, 2]
end
end
end