diff --git a/Project.toml b/Project.toml
index 499edd4..94083ad 100644
--- a/Project.toml
+++ b/Project.toml
@@ -15,7 +15,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 Distributions = "0.25"
 Random = "<0.0.1, 1"
 Reexport = "1"
-Statistics = "1"
-XAIBase = "3"
+Statistics = "<0.0.1, 1"
+XAIBase = "4"
 Zygote = "0.6"
 julia = "1.6"
diff --git a/src/ExplainableAI.jl b/src/ExplainableAI.jl
index 19e3149..2b33379 100644
--- a/src/ExplainableAI.jl
+++ b/src/ExplainableAI.jl
@@ -2,6 +2,7 @@ module ExplainableAI
 
 using Reexport
 @reexport using XAIBase
+import XAIBase: call_analyzer
 
 using Base.Iterators
 using Distributions: Distribution, Sampleable, Normal
diff --git a/src/gradcam.jl b/src/gradcam.jl
index 88f2e61..421ff5f 100644
--- a/src/gradcam.jl
+++ b/src/gradcam.jl
@@ -19,7 +19,7 @@ struct GradCAM{F,A} <: AbstractXAIMethod
     feature_layers::F
     adaptation_layers::A
 end
 
-function (analyzer::GradCAM)(input, ns::AbstractOutputSelector)
+function call_analyzer(input, analyzer::GradCAM, ns::AbstractOutputSelector; kwargs...)
     A = analyzer.feature_layers(input)  # feature map
     feature_map_size = size(A, 1) * size(A, 2)
@@ -27,5 +27,5 @@ function (analyzer::GradCAM)(input, ns::AbstractOutputSelector)
     grad, output, output_indices = gradient_wrt_input(analyzer.adaptation_layers, A, ns)
     αᶜ = sum(grad; dims=(1, 2)) / feature_map_size
     Lᶜ = max.(sum(αᶜ .* A; dims=3), 0)
-    return Explanation(Lᶜ, output, output_indices, :GradCAM, :cam, nothing)
+    return Explanation(Lᶜ, input, output, output_indices, :GradCAM, :cam, nothing)
 end
diff --git a/src/gradient.jl b/src/gradient.jl
index 409119a..5fea07b 100644
--- a/src/gradient.jl
+++ b/src/gradient.jl
@@ -19,9 +19,11 @@ struct Gradient{M} <: AbstractXAIMethod
     Gradient(model) = new{typeof(model)}(model)
 end
 
-function (analyzer::Gradient)(input, ns::AbstractOutputSelector)
+function call_analyzer(input, analyzer::Gradient, ns::AbstractOutputSelector; kwargs...)
     grad, output, output_indices = gradient_wrt_input(analyzer.model, input, ns)
-    return Explanation(grad, output, output_indices, :Gradient, :sensitivity, nothing)
+    return Explanation(
+        grad, input, output, output_indices, :Gradient, :sensitivity, nothing
+    )
 end
 
 """
@@ -35,11 +37,13 @@ struct InputTimesGradient{M} <: AbstractXAIMethod
     InputTimesGradient(model) = new{typeof(model)}(model)
 end
 
-function (analyzer::InputTimesGradient)(input, ns::AbstractOutputSelector)
+function call_analyzer(
+    input, analyzer::InputTimesGradient, ns::AbstractOutputSelector; kwargs...
+)
     grad, output, output_indices = gradient_wrt_input(analyzer.model, input, ns)
     attr = input .* grad
     return Explanation(
-        attr, output, output_indices, :InputTimesGradient, :attribution, nothing
+        attr, input, output, output_indices, :InputTimesGradient, :attribution, nothing
     )
 end
 
diff --git a/src/input_augmentation.jl b/src/input_augmentation.jl
index f0b5c9d..a99edbe 100644
--- a/src/input_augmentation.jl
+++ b/src/input_augmentation.jl
@@ -103,7 +103,7 @@ function NoiseAugmentation(analyzer, n, σ::Real=0.1f0, args...)
     return NoiseAugmentation(analyzer, n, Normal(0.0f0, Float32(σ)^2), args...)
 end
 
-function (aug::NoiseAugmentation)(input, ns::AbstractOutputSelector)
+function call_analyzer(input, aug::NoiseAugmentation, ns::AbstractOutputSelector; kwargs...)
     # Regular forward pass of model
     output = aug.analyzer.model(input)
     output_indices = ns(output)
@@ -116,6 +116,7 @@ function (aug::NoiseAugmentation)(input, ns::AbstractOutputSelector)
     # Average explanation
     return Explanation(
         reduce_augmentation(augmented_expl.val, aug.n),
+        input,
         output,
         output_indices,
         augmented_expl.analyzer,
@@ -141,8 +142,8 @@ struct InterpolationAugmentation{A<:AbstractXAIMethod} <: AbstractXAIMethod
     n::Int
 end
 
-function (aug::InterpolationAugmentation)(
-    input, ns::AbstractOutputSelector; input_ref=zero(input)
+function call_analyzer(
+    input, aug::InterpolationAugmentation, ns::AbstractOutputSelector; input_ref=zero(input)
 )
     size(input) != size(input_ref) &&
         throw(ArgumentError("Input reference size doesn't match input size."))
@@ -161,6 +162,7 @@ function (aug::InterpolationAugmentation)(
 
     return Explanation(
         expl,
+        input,
         output,
         output_indices,
         augmented_expl.analyzer,
diff --git a/test/runtests.jl b/test/runtests.jl
index ab3c7a1..e73db4d 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -6,18 +6,18 @@ using Aqua
 using JET
 
 @testset "ExplainableAI.jl" begin
-    @info "Testing formalities..."
     if VERSION >= v"1.10"
+        @info "Testing formalities..."
         @testset "Code formatting" begin
-            @info "- Testing code formatting with JuliaFormatter..."
+            @info "- running JuliaFormatter code formatting tests..."
             @test JuliaFormatter.format(ExplainableAI; verbose=false, overwrite=false)
         end
         @testset "Aqua.jl" begin
-            @info "- Running Aqua.jl tests. These might print warnings from dependencies..."
+            @info "- running Aqua.jl tests. These might print warnings from dependencies..."
             Aqua.test_all(ExplainableAI; ambiguities=false)
         end
         @testset "JET tests" begin
-            @info "- Testing type stability with JET..."
+            @info "- running JET.jl type stability tests..."
             JET.test_package(ExplainableAI; target_defined_modules=true)
         end
     end
diff --git a/test/test_batches.jl b/test/test_batches.jl
index ed34490..473073d 100644
--- a/test/test_batches.jl
+++ b/test/test_batches.jl
@@ -14,14 +14,12 @@ batchsize = 15
 
 model = Chain(Dense(ins, 15, relu; init=pseudorand), Dense(15, outs, relu; init=pseudorand))
 
-# Input 1 w/o batch dimension
-input1_no_bd = rand(MersenneTwister(1), Float32, ins)
 # Input 1 with batch dimension
-input1_bd = reshape(input1_no_bd, ins, 1)
+input1 = rand(MersenneTwister(1), Float32, ins, 1)
 # Input 2 with batch dimension
-input2_bd = rand(MersenneTwister(2), Float32, ins, 1)
+input2 = rand(MersenneTwister(2), Float32, ins, 1)
 # Batch containing inputs 1 & 2
-input_batch = cat(input1_bd, input2_bd; dims=2)
+input_batch = cat(input1, input2; dims=2)
 
 ANALYZERS = Dict(
     "Gradient" => Gradient,
@@ -33,25 +31,20 @@
 for (name, method) in ANALYZERS
     @testset "$name" begin
-        # Using `add_batch_dim=true` should result in same explanation
-        # as input reshaped to have a batch dimension
         analyzer = method(model)
-        expl1_no_bd = analyzer(input1_no_bd; add_batch_dim=true)
-        analyzer = method(model)
-        expl1_bd = analyzer(input1_bd)
-        @test expl1_bd.val ≈ expl1_no_bd.val
+        expl1 = analyzer(input1)
 
         # Analyzing a batch should have the same result
         # as analyzing inputs in batch individually
         analyzer = method(model)
-        expl2_bd = analyzer(input2_bd)
+        expl2 = analyzer(input2)
         analyzer = method(model)
         expl_batch = analyzer(input_batch)
-        @test expl1_bd.val ≈ expl_batch.val[:, 1]
+        @test expl1.val ≈ expl_batch.val[:, 1]
 
         if !(analyzer isa NoiseAugmentation)
             # NoiseAugmentation methods generate random numbers for the entire batch.
             # therefore explanations don't match except for the first input in the batch.
-            @test expl2_bd.val ≈ expl_batch.val[:, 2]
+            @test expl2.val ≈ expl_batch.val[:, 2]
         end
     end
 end
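
A minimal sketch of what this migration means for a downstream XAI method, assuming the XAIBase v4 interface used throughout this diff: methods now extend `call_analyzer` (with the analyzer as second argument and trailing kwargs) instead of making the analyzer callable, and `Explanation` stores the analyzed input as its second field. `MyMethod` and its zero-valued attribution are hypothetical placeholders for illustration, not part of this PR:

using XAIBase
import XAIBase: call_analyzer

# Hypothetical analyzer type, used only to illustrate the v4 interface.
struct MyMethod{M} <: AbstractXAIMethod
    model::M
end

# XAIBase v3 made analyzers callable:
#     (method::MyMethod)(input, ns::AbstractOutputSelector)
# XAIBase v4 dispatches on `call_analyzer` instead and forwards kwargs:
function call_analyzer(input, method::MyMethod, ns::AbstractOutputSelector; kwargs...)
    output = method.model(input)
    output_selection = ns(output)
    val = zero(input)  # placeholder attribution; a real method computes this
    # v4 `Explanation` takes the input as its second positional argument:
    return Explanation(val, input, output, output_selection, :MyMethod, :attribution, nothing)
end

# Callers are unaffected, since XAIBase routes `analyze` through `call_analyzer`:
#     expl = analyze(input, MyMethod(model))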