Skip to content

Commit 032b617

Browse files
committed
Test IntegratedGradients as well
1 parent 5974a12 commit 032b617

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

src/input_augmentation.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,9 +123,9 @@ function call_analyzer(
123123
# Further augmentations
124124
input_delta = (input - input_ref) / (aug.n - 1)
125125
for _ in 1:(aug.n)
126-
input_aug += input_delta
126+
input_aug .+= input_delta
127127
expl_aug = aug.analyzer(input_aug, output_selector)
128-
sum_val += expl_aug.val
128+
sum_val .+= expl_aug.val
129129
end
130130

131131
# Average gradients and compute explanation

test/test_gpu.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ model_gpu = device(model)
2020
input_gpu = device(input)
2121
@test_nowarn model_gpu(input_gpu)
2222

23-
analyzer_types = (Gradient, SmoothGrad, InputTimesGradient)
23+
analyzer_types = (Gradient, SmoothGrad, InputTimesGradient, IntegratedGradients)
2424

2525
@testset "Run analyzer (CPU)" begin
2626
@testset "$A" for A in analyzer_types

0 commit comments

Comments
 (0)