1
# Removed version (lines marked `-`): takes already-computed `output_indices` and
# differentiates `model(in)[output_indices]` with respect to `input`.
# `only` unwraps the single-element result of `gradient`
# (presumably `Zygote.gradient` -- TODO confirm against the file's imports).
- function gradient_wrt_input (model, input, output_indices )
2
- return only ( gradient ((in) -> model (in)[output_indices] , input) )
3
- end
1
# Added version (lines marked `+`): runs the forward pass once via `Zygote.pullback`,
# lets the neuron selector `ns` pick indices from the resulting `output`, and
# back-propagates a mask vector through the pullback.
# Returns the tuple `(grad, output, output_indices)`.
+ function gradient_wrt_input (model, input, ns :: AbstractNeuronSelector )
2
# `output` is the forward result `model(input)`; `back` is the pullback closure.
+ output, back = Zygote . pullback (model , input)
3
# The selector is applied to the model output to choose which entries to explain.
+ output_indices = ns (output)
4
4
5
# Removed helper (lines marked `-`): filled `out` by looping over the last axis of
# `input` and calling `gradient_wrt_input` once per slice.
- function gradients_wrt_batch (model, input:: AbstractArray{T,N} , output_indices) where {T,N}
6
- # To avoid computing a sparse jacobian, we compute individual gradients
7
- # by calling `gradient_wrt_input` on slices of the input along the batch dimension.
8
- out = similar (input)
9
- inds_before_N = ntuple (Returns (:), N - 1 )
10
- for (i, ax) in enumerate (axes (input, N))
11
- view (out, inds_before_N... , ax, :) .= gradient_wrt_input (
12
- model, view (input, inds_before_N... , ax, :), drop_batch_index (output_indices[i])
13
- )
14
- end
15
- return out
5
+ # Compute the VJP w.r.t. the full model output, using a seed vector `v` that is 1 at the selected output neurons and 0 elsewhere
6
+ v = zero (output)
7
+ v[output_indices] .= 1
8
# `only` requires `back(v)` to hold exactly one element and unwraps it; with a single
# differentiated argument this is the cotangent for `input`.
+ grad = only (back (v))
9
+ return grad, output, output_indices
16
10
end
17
11
18
12
"""
@@ -25,9 +19,7 @@ struct Gradient{C<:Chain} <: AbstractXAIMethod
25
19
Gradient (model:: Chain ) = new {typeof(model)} (Flux. testmode! (check_output_softmax (model)))
26
20
end
27
21
# Call overload for `Gradient` analyzers: after this change the gradient, the model
# output and the selected output indices all come from a single call to
# `gradient_wrt_input` on `analyzer.model` (the `-` lines show the removed
# forward-pass + `gradients_wrt_batch` path). The result is wrapped in an
# `Explanation` tagged `:Gradient`.
function (analyzer:: Gradient )(input, ns:: AbstractNeuronSelector )
28
- output = analyzer. model (input)
29
- output_indices = ns (output)
30
- grad = gradients_wrt_batch (analyzer. model, input, output_indices)
22
+ grad, output, output_indices = gradient_wrt_input (analyzer. model, input, ns)
31
23
# NOTE(review): the meaning of the trailing `nothing` argument is defined by
# `Explanation`, which is not visible here -- confirm.
return Explanation (grad, output, output_indices, :Gradient , nothing )
32
24
end
33
25
@@ -44,9 +36,8 @@ struct InputTimesGradient{C<:Chain} <: AbstractXAIMethod
44
36
end
45
37
end
46
38
# Call overload for `InputTimesGradient` analyzers: obtains `grad`, `output` and
# `output_indices` from `gradient_wrt_input` on `analyzer.model`, then forms the
# attribution as the element-wise product `input .* grad` (the `-` lines show the
# removed `gradients_wrt_batch` path). The result is wrapped in an `Explanation`
# tagged `:InputTimesGradient`.
function (analyzer:: InputTimesGradient )(input, ns:: AbstractNeuronSelector )
47
- output = analyzer. model (input)
48
- output_indices = ns (output)
49
- attr = input .* gradients_wrt_batch (analyzer. model, input, output_indices)
39
+ grad, output, output_indices = gradient_wrt_input (analyzer. model, input, ns)
40
# Broadcasted multiplication: each input entry is scaled by its gradient entry.
+ attr = input .* grad
50
41
# NOTE(review): the meaning of the trailing `nothing` argument is defined by
# `Explanation`, which is not visible here -- confirm.
return Explanation (attr, output, output_indices, :InputTimesGradient , nothing )
51
42
end
52
43
0 commit comments