Skip to content

Commit d9ea6eb

Browse files
authored
GPU support for gradient analyzers (#144)
1 parent 78cedfc commit d9ea6eb

File tree

1 file changed

+11
-20
lines changed

1 file changed

+11
-20
lines changed

src/gradient.jl

Lines changed: 11 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,12 @@
1-
function gradient_wrt_input(model, input, output_indices)
2-
return only(gradient((in) -> model(in)[output_indices], input))
3-
end
1+
function gradient_wrt_input(model, input, ns::AbstractNeuronSelector)
2+
output, back = Zygote.pullback(model, input)
3+
output_indices = ns(output)
44

5-
function gradients_wrt_batch(model, input::AbstractArray{T,N}, output_indices) where {T,N}
6-
# To avoid computing a sparse jacobian, we compute individual gradients
7-
# by calling `gradient_wrt_input` on slices of the input along the batch dimension.
8-
out = similar(input)
9-
inds_before_N = ntuple(Returns(:), N - 1)
10-
for (i, ax) in enumerate(axes(input, N))
11-
view(out, inds_before_N..., ax, :) .= gradient_wrt_input(
12-
model, view(input, inds_before_N..., ax, :), drop_batch_index(output_indices[i])
13-
)
14-
end
15-
return out
5+
# Compute VJP w.r.t. full model output, selecting vector s.t. it masks output neurons
6+
v = zero(output)
7+
v[output_indices] .= 1
8+
grad = only(back(v))
9+
return grad, output, output_indices
1610
end
1711

1812
"""
@@ -25,9 +19,7 @@ struct Gradient{C<:Chain} <: AbstractXAIMethod
2519
Gradient(model::Chain) = new{typeof(model)}(Flux.testmode!(check_output_softmax(model)))
2620
end
2721
function (analyzer::Gradient)(input, ns::AbstractNeuronSelector)
28-
output = analyzer.model(input)
29-
output_indices = ns(output)
30-
grad = gradients_wrt_batch(analyzer.model, input, output_indices)
22+
grad, output, output_indices = gradient_wrt_input(analyzer.model, input, ns)
3123
return Explanation(grad, output, output_indices, :Gradient, nothing)
3224
end
3325

@@ -44,9 +36,8 @@ struct InputTimesGradient{C<:Chain} <: AbstractXAIMethod
4436
end
4537
end
4638
function (analyzer::InputTimesGradient)(input, ns::AbstractNeuronSelector)
47-
output = analyzer.model(input)
48-
output_indices = ns(output)
49-
attr = input .* gradients_wrt_batch(analyzer.model, input, output_indices)
39+
grad, output, output_indices = gradient_wrt_input(analyzer.model, input, ns)
40+
attr = input .* grad
5041
return Explanation(attr, output, output_indices, :InputTimesGradient, nothing)
5142
end
5243

0 commit comments

Comments
 (0)