Skip to content

Commit 5974a12

Browse files
committed
GPU friendly implementation of SmoothGrad
1 parent a9a493b commit 5974a12

File tree

2 files changed

+12
-8
lines changed

2 files changed

+12
-8
lines changed

src/ExplainableAI.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@ import XAIBase: call_analyzer
66

77
using Base.Iterators
88
using Distributions: Distribution, Sampleable, Normal
9-
using Random: AbstractRNG, GLOBAL_RNG
9+
using Random: AbstractRNG, GLOBAL_RNG, rand!
10+
using ProgressMeter: Progress, next!
1011

1112
# Automatic differentiation
1213
using ADTypes: AbstractADType, AutoZygote

src/input_augmentation.jl

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -52,18 +52,19 @@ function call_analyzer(input, aug::NoiseAugmentation, ns::AbstractOutputSelector
5252
output_selector = AugmentationSelector(output_indices)
5353

5454
p = Progress(aug.n; desc="Sampling NoiseAugmentation...", showspeed=aug.show_progress)
55+
5556
# First augmentation
56-
input_aug = similar(input)
57-
input_aug = sample_noise!(input_aug, input, aug)
58-
expl_aug = aug.analyzer(input_aug, output_selector)
57+
noisy_input = similar(input)
58+
noisy_input = sample_noise!(noisy_input, input, aug)
59+
expl_aug = aug.analyzer(noisy_input, output_selector)
5960
sum_val = expl_aug.val
6061
next!(p)
6162

6263
# Further augmentations
6364
for _ in 2:(aug.n)
64-
input_aug = sample_noise!(input_aug, input, aug)
65-
expl_aug = aug.analyzer(input_aug, output_selector)
66-
sum_val += expl_aug.val
65+
noisy_input = sample_noise!(noisy_input, input, aug)
66+
expl_aug = aug.analyzer(noisy_input, output_selector)
67+
sum_val .+= expl_aug.val
6768
next!(p)
6869
end
6970

@@ -78,7 +79,9 @@ end
7879
function sample_noise!(
7980
out::A, input::A, aug::NoiseAugmentation
8081
) where {T,A<:AbstractArray{T}}
81-
out .= input .+ rand(aug.rng, aug.distribution, size(input))
82+
out = rand!(aug.rng, aug.distribution, out)
83+
out .+= input
84+
return out
8285
end
8386

8487
"""

0 commit comments

Comments
 (0)