Skip to content

Commit 5dd3dfd

Browse files
authored
Allow Gradient analyzers on non-Flux models (#150)
* Allow Gradient analyzers on non-Flux models
* Fix typo in `BATCHDIM_MISSING` error
1 parent bfaf500 commit 5dd3dfd

File tree

3 files changed

+17
-15
lines changed

3 files changed

+17
-15
lines changed

src/analyze_api.jl

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ abstract type AbstractXAIMethod end
44

55
const BATCHDIM_MISSING = ArgumentError(
66
"""The input is a 1D vector and therefore missing the required batch dimension.
7-
Call `analyze` with the keyword argument `add_batch_dim=false`."""
7+
Call `analyze` with the keyword argument `add_batch_dim=true`."""
88
)
99

1010
"""
@@ -46,16 +46,14 @@ end
4646

4747
# lower-level call to method
4848
function _analyze(
49-
input::AbstractArray{T,N},
49+
input::AbstractArray,
5050
method::AbstractXAIMethod,
5151
sel::AbstractNeuronSelector;
5252
add_batch_dim::Bool=false,
5353
kwargs...,
54-
) where {T<:Real,N}
55-
if add_batch_dim
56-
return method(batch_dim_view(input), sel; kwargs...)
57-
end
58-
N < 2 && throw(BATCHDIM_MISSING)
54+
)
55+
add_batch_dim && (input = batch_dim_view(input))
56+
ndims(input) < 2 && throw(BATCHDIM_MISSING)
5957
return method(input, sel; kwargs...)
6058
end
6159

src/gradient.jl

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,12 @@ end
1414
1515
Analyze model by calculating the gradient of a neuron activation with respect to the input.
1616
"""
17-
struct Gradient{C<:Chain} <: AbstractXAIMethod
18-
model::C
17+
struct Gradient{M} <: AbstractXAIMethod
18+
model::M
19+
Gradient(model) = new{typeof(model)}(model)
1920
Gradient(model::Chain) = new{typeof(model)}(Flux.testmode!(check_output_softmax(model)))
2021
end
22+
2123
function (analyzer::Gradient)(input, ns::AbstractNeuronSelector)
2224
grad, output, output_indices = gradient_wrt_input(analyzer.model, input, ns)
2325
return Explanation(grad, output, output_indices, :Gradient, nothing)
@@ -29,12 +31,14 @@ end
2931
Analyze model by calculating the gradient of a neuron activation with respect to the input.
3032
This gradient is then multiplied element-wise with the input.
3133
"""
32-
struct InputTimesGradient{C<:Chain} <: AbstractXAIMethod
33-
model::C
34+
struct InputTimesGradient{M} <: AbstractXAIMethod
35+
model::M
36+
InputTimesGradient(model) = new{typeof(model)}(model)
3437
function InputTimesGradient(model::Chain)
35-
return new{typeof(model)}(Flux.testmode!(check_output_softmax(model)))
38+
new{typeof(model)}(Flux.testmode!(check_output_softmax(model)))
3639
end
3740
end
41+
3842
function (analyzer::InputTimesGradient)(input, ns::AbstractNeuronSelector)
3943
grad, output, output_indices = gradient_wrt_input(analyzer.model, input, ns)
4044
attr = input .* grad

src/heatmap.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,15 +42,15 @@ See also [`analyze`](@ref).
4242
instead of computing it individually for each sample. Defaults to `false`.
4343
"""
4444
function heatmap(
45-
val::AbstractArray{T,N};
45+
val::AbstractArray;
4646
cs::ColorScheme=ColorSchemes.seismic,
4747
reduce::Symbol=:sum,
4848
rangescale::Symbol=:centered,
4949
permute::Bool=true,
5050
unpack_singleton::Bool=true,
5151
process_batch::Bool=false,
52-
) where {T,N}
53-
N != 4 && throw(
52+
)
53+
ndims(val) != 4 && throw(
5454
ArgumentError(
5555
"heatmap assumes Flux's WHCN convention (width, height, color channels, batch size) for the input.
5656
Please reshape your explanation to match this format if your model doesn't adhere to this convention.",

0 commit comments

Comments (0)