@@ -5,15 +5,15 @@ function forward_with_output_selection(model, input, selector::AbstractOutputSel
55end
66
77function gradient_wrt_input(
8- model, input, output_selector:: AbstractOutputSelector , backend:: AbstractADType
9- )
8+ model, input, output_selector:: AbstractOutputSelector , backend:: AbstractADType
9+ )
1010 output = model(input)
1111 return gradient_wrt_input(model, input, output, output_selector, backend)
1212end
1313
1414function gradient_wrt_input(
15- model, input, output, output_selector:: AbstractOutputSelector , backend:: AbstractADType
16- )
15+ model, input, output, output_selector:: AbstractOutputSelector , backend:: AbstractADType
16+ )
1717 output_selection = output_selector(output)
1818 dy = zero(output)
1919 dy[output_selection] .= 1
2828
2929Analyze model by calculating the gradient of a neuron activation with respect to the input.
3030"""
31- struct Gradient{M,B <: AbstractADType } <: AbstractXAIMethod
31+ struct Gradient{M, B <: AbstractADType } <: AbstractXAIMethod
3232 model:: M
3333 backend:: B
3434
35- function Gradient(model:: M , backend:: B = DEFAULT_AD_BACKEND) where {M,B <: AbstractADType }
36- new{M,B}(model, backend)
35+ function Gradient(model:: M , backend:: B = DEFAULT_AD_BACKEND) where {M, B <: AbstractADType }
36+ return new{M, B}(model, backend)
3737 end
3838end
3939
5252Analyze model by calculating the gradient of a neuron activation with respect to the input.
5353This gradient is then multiplied element-wise with the input.
5454"""
55- struct InputTimesGradient{M,B <: AbstractADType } <: AbstractXAIMethod
55+ struct InputTimesGradient{M, B <: AbstractADType } <: AbstractXAIMethod
5656 model:: M
5757 backend:: B
5858
5959 function InputTimesGradient(
60- model:: M , backend:: B = DEFAULT_AD_BACKEND
61- ) where {M,B <: AbstractADType }
62- new{M,B}(model, backend)
60+ model:: M , backend:: B = DEFAULT_AD_BACKEND
61+ ) where {M, B <: AbstractADType }
62+ return new{M, B}(model, backend)
6363 end
6464end
6565
6666function call_analyzer(
67- input, analyzer:: InputTimesGradient , ns:: AbstractOutputSelector ; kwargs...
68- )
67+ input, analyzer:: InputTimesGradient , ns:: AbstractOutputSelector ; kwargs...
68+ )
6969 grad, output, output_indices = gradient_wrt_input(
7070 analyzer. model, input, ns, analyzer. backend
7171 )
@@ -91,7 +91,7 @@ e.g. `std = 0.1 * (maximum(input) - minimum(input))`.
9191# References
9292- $REF_SMILKOV_SMOOTHGRAD
9393"""
94- SmoothGrad(model, n= 50 , args... ) = NoiseAugmentation(Gradient(model), n, args... )
94+ SmoothGrad(model, n = 50 , args... ) = NoiseAugmentation(Gradient(model), n, args... )
9595
9696"""
9797 IntegratedGradients(analyzer, [n=50])
@@ -102,4 +102,4 @@ Analyze model by using the Integrated Gradients method.
102102# References
103103- $REF_SUNDARARAJAN_AXIOMATIC
104104"""
105- IntegratedGradients(model, n= 50 ) = InterpolationAugmentation(Gradient(model), n)
105+ IntegratedGradients(model, n = 50 ) = InterpolationAugmentation(Gradient(model), n)
0 commit comments