Skip to content

Commit 0135250

Browse files
authored
Update XAIBase dependency to v4 (#166)
1 parent 124ab2f commit 0135250

File tree

7 files changed

+30
-29
lines changed

7 files changed

+30
-29
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
1515
Distributions = "0.25"
1616
Random = "<0.0.1, 1"
1717
Reexport = "1"
18-
Statistics = "1"
19-
XAIBase = "3"
18+
Statistics = "<0.0.1, 1"
19+
XAIBase = "4"
2020
Zygote = "0.6"
2121
julia = "1.6"

src/ExplainableAI.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ module ExplainableAI
22

33
using Reexport
44
@reexport using XAIBase
5+
import XAIBase: call_analyzer
56

67
using Base.Iterators
78
using Distributions: Distribution, Sampleable, Normal

src/gradcam.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,13 @@ struct GradCAM{F,A} <: AbstractXAIMethod
1919
feature_layers::F
2020
adaptation_layers::A
2121
end
22-
function (analyzer::GradCAM)(input, ns::AbstractOutputSelector)
22+
function call_analyzer(input, analyzer::GradCAM, ns::AbstractOutputSelector; kwargs...)
2323
A = analyzer.feature_layers(input) # feature map
2424
feature_map_size = size(A, 1) * size(A, 2)
2525

2626
# Determine neuron importance αₖᶜ = 1/Z * ∑ᵢ ∑ⱼ ∂yᶜ / ∂Aᵢⱼᵏ
2727
grad, output, output_indices = gradient_wrt_input(analyzer.adaptation_layers, A, ns)
2828
αᶜ = sum(grad; dims=(1, 2)) / feature_map_size
2929
Lᶜ = max.(sum(αᶜ .* A; dims=3), 0)
30-
return Explanation(Lᶜ, output, output_indices, :GradCAM, :cam, nothing)
30+
return Explanation(Lᶜ, input, output, output_indices, :GradCAM, :cam, nothing)
3131
end

src/gradient.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,11 @@ struct Gradient{M} <: AbstractXAIMethod
1919
Gradient(model) = new{typeof(model)}(model)
2020
end
2121

22-
function (analyzer::Gradient)(input, ns::AbstractOutputSelector)
22+
function call_analyzer(input, analyzer::Gradient, ns::AbstractOutputSelector; kwargs...)
2323
grad, output, output_indices = gradient_wrt_input(analyzer.model, input, ns)
24-
return Explanation(grad, output, output_indices, :Gradient, :sensitivity, nothing)
24+
return Explanation(
25+
grad, input, output, output_indices, :Gradient, :sensitivity, nothing
26+
)
2527
end
2628

2729
"""
@@ -35,11 +37,13 @@ struct InputTimesGradient{M} <: AbstractXAIMethod
3537
InputTimesGradient(model) = new{typeof(model)}(model)
3638
end
3739

38-
function (analyzer::InputTimesGradient)(input, ns::AbstractOutputSelector)
40+
function call_analyzer(
41+
input, analyzer::InputTimesGradient, ns::AbstractOutputSelector; kwargs...
42+
)
3943
grad, output, output_indices = gradient_wrt_input(analyzer.model, input, ns)
4044
attr = input .* grad
4145
return Explanation(
42-
attr, output, output_indices, :InputTimesGradient, :attribution, nothing
46+
attr, input, output, output_indices, :InputTimesGradient, :attribution, nothing
4347
)
4448
end
4549

src/input_augmentation.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ function NoiseAugmentation(analyzer, n, σ::Real=0.1f0, args...)
103103
return NoiseAugmentation(analyzer, n, Normal(0.0f0, Float32(σ)^2), args...)
104104
end
105105

106-
function (aug::NoiseAugmentation)(input, ns::AbstractOutputSelector)
106+
function call_analyzer(input, aug::NoiseAugmentation, ns::AbstractOutputSelector; kwargs...)
107107
# Regular forward pass of model
108108
output = aug.analyzer.model(input)
109109
output_indices = ns(output)
@@ -116,6 +116,7 @@ function (aug::NoiseAugmentation)(input, ns::AbstractOutputSelector)
116116
# Average explanation
117117
return Explanation(
118118
reduce_augmentation(augmented_expl.val, aug.n),
119+
input,
119120
output,
120121
output_indices,
121122
augmented_expl.analyzer,
@@ -141,8 +142,8 @@ struct InterpolationAugmentation{A<:AbstractXAIMethod} <: AbstractXAIMethod
141142
n::Int
142143
end
143144

144-
function (aug::InterpolationAugmentation)(
145-
input, ns::AbstractOutputSelector; input_ref=zero(input)
145+
function call_analyzer(
146+
input, aug::InterpolationAugmentation, ns::AbstractOutputSelector; input_ref=zero(input)
146147
)
147148
size(input) != size(input_ref) &&
148149
throw(ArgumentError("Input reference size doesn't match input size."))
@@ -161,6 +162,7 @@ function (aug::InterpolationAugmentation)(
161162

162163
return Explanation(
163164
expl,
165+
input,
164166
output,
165167
output_indices,
166168
augmented_expl.analyzer,

test/runtests.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,18 +6,18 @@ using Aqua
66
using JET
77

88
@testset "ExplainableAI.jl" begin
9-
@info "Testing formalities..."
109
if VERSION >= v"1.10"
10+
@info "Testing formalities..."
1111
@testset "Code formatting" begin
12-
@info "- Testing code formatting with JuliaFormatter..."
12+
@info "- running JuliaFormatter code formatting tests..."
1313
@test JuliaFormatter.format(ExplainableAI; verbose=false, overwrite=false)
1414
end
1515
@testset "Aqua.jl" begin
16-
@info "- Running Aqua.jl tests. These might print warnings from dependencies..."
16+
@info "- running Aqua.jl tests. These might print warnings from dependencies..."
1717
Aqua.test_all(ExplainableAI; ambiguities=false)
1818
end
1919
@testset "JET tests" begin
20-
@info "- Testing type stability with JET..."
20+
@info "- running JET.jl type stability tests..."
2121
JET.test_package(ExplainableAI; target_defined_modules=true)
2222
end
2323
end

test/test_batches.jl

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,12 @@ batchsize = 15
1414

1515
model = Chain(Dense(ins, 15, relu; init=pseudorand), Dense(15, outs, relu; init=pseudorand))
1616

17-
# Input 1 w/o batch dimension
18-
input1_no_bd = rand(MersenneTwister(1), Float32, ins)
1917
# Input 1 with batch dimension
20-
input1_bd = reshape(input1_no_bd, ins, 1)
18+
input1 = rand(MersenneTwister(1), Float32, ins, 1)
2119
# Input 2 with batch dimension
22-
input2_bd = rand(MersenneTwister(2), Float32, ins, 1)
20+
input2 = rand(MersenneTwister(2), Float32, ins, 1)
2321
# Batch containing inputs 1 & 2
24-
input_batch = cat(input1_bd, input2_bd; dims=2)
22+
input_batch = cat(input1, input2; dims=2)
2523

2624
ANALYZERS = Dict(
2725
"Gradient" => Gradient,
@@ -33,25 +31,21 @@ ANALYZERS = Dict(
3331

3432
for (name, method) in ANALYZERS
3533
@testset "$name" begin
36-
# Using `add_batch_dim=true` should result in same explanation
37-
# as input reshaped to have a batch dimension
3834
analyzer = method(model)
39-
expl1_no_bd = analyzer(input1_no_bd; add_batch_dim=true)
40-
analyzer = method(model)
41-
expl1_bd = analyzer(input1_bd)
42-
@test expl1_bd.val ≈ expl1_no_bd.val
35+
expl1 = analyzer(input1)
36+
@test expl1.val ≈ expl1.val
4337

4438
# Analyzing a batch should have the same result
4539
# as analyzing inputs in batch individually
4640
analyzer = method(model)
47-
expl2_bd = analyzer(input2_bd)
41+
expl2 = analyzer(input2)
4842
analyzer = method(model)
4943
expl_batch = analyzer(input_batch)
50-
@test expl1_bd.val ≈ expl_batch.val[:, 1]
44+
@test expl1.val ≈ expl_batch.val[:, 1]
5145
if !(analyzer isa NoiseAugmentation)
5246
# NoiseAugmentation methods generate random numbers for the entire batch.
5347
# therefore explanations don't match except for the first input in the batch.
54-
@test expl2_bd.val ≈ expl_batch.val[:, 2]
48+
@test expl2.val ≈ expl_batch.val[:, 2]
5549
end
5650
end
5751
end

0 commit comments

Comments
 (0)