@@ -13,7 +13,7 @@ function lrp!(Rₖ, rule::R, layer::L, aₖ, Rₖ₊₁) where {R<:AbstractLRPRu
13
13
check_compat (rule, layer)
14
14
reset! = get_layer_resetter (rule, layer)
15
15
modify_layer! (rule, layer)
16
- ãₖ₊₁, pullback = Zygote. pullback (layer, modify_input (rule, aₖ))
16
+ ãₖ₊₁, pullback = Zygote. pullback (preactivation ( layer) , modify_input (rule, aₖ))
17
17
Rₖ .= aₖ .* only (pullback (Rₖ₊₁ ./ modify_denominator (rule, ãₖ₊₁)))
18
18
reset! ()
19
19
return nothing
@@ -180,13 +180,15 @@ modify_input(::WSquareRule, input) = ones_like(input)
180
180
"""
181
181
FlatRule()
182
182
183
- LRP-Flat rule. Similar to the [`WSquareRule`](@ref), but with all parameters set to one.
183
+ LRP-Flat rule. Similar to the [`WSquareRule`](@ref), but with all weights set to one
184
+ and all bias terms set to zero.
184
185
185
186
# References
186
187
- $REF_LAPUSCHKIN_CLEVER_HANS
187
188
"""
188
189
struct FlatRule <: AbstractLRPRule end
189
- modify_param! (:: FlatRule , p) = fill! (p, 1 )
190
+ modify_weight! (:: FlatRule , w) = fill! (w, 1 )
191
+ modify_bias! (:: FlatRule , b) = fill! (b, 0 )
190
192
modify_input (:: FlatRule , input) = ones_like (input)
191
193
192
194
"""
@@ -232,20 +234,25 @@ function lrp!(Rₖ, rule::ZBoxRule, layer::L, aₖ, Rₖ₊₁) where {L}
232
234
h = zbox_input (aₖ, rule. high)
233
235
234
236
# Compute pullback for W, b
235
- aₖ₊₁, pullback = Zygote. pullback (layer, aₖ)
237
+ aₖ₊₁, pullback = Zygote. pullback (preactivation ( layer) , aₖ)
236
238
237
239
# Compute pullback for W⁺, b⁺
238
240
modify_layer! (Val (:keep_positive ), layer)
239
- aₖ₊₁⁺, pullback⁺ = Zygote. pullback (layer, l)
241
+ aₖ₊₁⁺, pullback⁺ = Zygote. pullback (preactivation ( layer) , l)
240
242
reset! ()
241
243
242
244
# Compute pullback for W⁻, b⁻
243
245
modify_layer! (Val (:keep_negative ), layer)
244
- aₖ₊₁⁻, pullback⁻ = Zygote. pullback (layer, h)
245
- reset! ()
246
+ aₖ₊₁⁻, pullback⁻ = Zygote. pullback (preactivation (layer), h)
246
247
248
+ # Evaluate pullbacks
247
249
y = Rₖ₊₁ ./ modify_denominator (rule, aₖ₊₁ - aₖ₊₁⁺ - aₖ₊₁⁻)
248
- Rₖ .= aₖ .* only (pullback (y)) - l .* only (pullback⁺ (y)) - h .* only (pullback⁻ (y))
250
+ Rₖ .= - h .* only (pullback⁻ (y))
251
+ reset! () # re-modify mutated pullback
252
+ Rₖ .+ = aₖ .* only (pullback (y))
253
+ modify_layer! (Val (:keep_positive ), layer) # re-modify mutated pullback
254
+ Rₖ .- = l .* only (pullback⁺ (y))
255
+ reset! ()
249
256
return nothing
250
257
end
251
258
@@ -278,7 +285,7 @@ struct AlphaBetaRule{T} <: AbstractLRPRule
278
285
alpha < 0 && throw (ArgumentError (" Parameter `alpha` must be ≥0." ))
279
286
beta < 0 && throw (ArgumentError (" Parameter `beta` must be ≥0." ))
280
287
! isone (alpha - beta) && throw (ArgumentError (" `alpha - beta` must be equal one." ))
281
- return new {Float32 } (alpha, beta)
288
+ return new {eltype(alpha) } (alpha, beta)
282
289
end
283
290
end
284
291
@@ -290,24 +297,33 @@ function lrp!(Rₖ, rule::AlphaBetaRule, layer::L, aₖ, Rₖ₊₁) where {L}
290
297
aₖ⁺ = keep_positive (aₖ)
291
298
aₖ⁻ = keep_negative (aₖ)
292
299
300
+ # α: positive contributions
301
+ modify_layer! (Val (:keep_negative_zero_bias ), layer)
302
+ aₖ₊₁ᵅ⁻, pullbackᵅ⁻ = Zygote. pullback (preactivation (layer), aₖ⁻)
303
+ reset! ()
293
304
modify_layer! (Val (:keep_positive ), layer)
294
- out_1, pullback_1 = Zygote. pullback (layer, aₖ⁺)
305
+ aₖ₊₁ᵅ⁺, pullbackᵅ⁺ = Zygote. pullback (preactivation (layer), aₖ⁺)
306
+ # evaluate pullbacks
307
+ yᵅ = Rₖ₊₁ ./ modify_denominator (rule, aₖ₊₁ᵅ⁺ + aₖ₊₁ᵅ⁻)
308
+ Rₖ .= rule. α .* aₖ⁺ .* only (pullbackᵅ⁺ (yᵅ))
295
309
reset! ()
296
- modify_layer! (Val (:keep_negative_zero_bias ), layer)
297
- out_2, pullback_2 = Zygote. pullback (layer, aₖ⁻)
310
+ modify_layer! (Val (:keep_negative_zero_bias ), layer) # re-modify mutated pullback
311
+ Rₖ .+ = rule. α .* aₖ⁻ .* only (pullbackᵅ⁻ (yᵅ))
312
+ reset! ()
313
+
314
+ # β: Negative contributions
315
+ modify_layer! (Val (:keep_positive_zero_bias ), layer)
316
+ aₖ₊₁ᵝ⁻, pullbackᵝ⁻ = Zygote. pullback (preactivation (layer), aₖ⁻) #
298
317
reset! ()
299
318
modify_layer! (Val (:keep_negative ), layer)
300
- out_3, pullback_3 = Zygote. pullback (layer, aₖ⁺)
319
+ aₖ₊₁ᵝ⁺, pullbackᵝ⁺ = Zygote. pullback (preactivation (layer), aₖ⁺)
320
+ # evaluate pullbacks
321
+ yᵝ = Rₖ₊₁ ./ modify_denominator (rule, aₖ₊₁ᵝ⁺ + aₖ₊₁ᵝ⁻)
322
+ Rₖ .- = rule. β .* aₖ⁺ .* only (pullbackᵝ⁺ (yᵝ))
301
323
reset! ()
302
- modify_layer! (Val (:keep_positive_zero_bias ), layer)
303
- out_4, pullback_4 = Zygote . pullback (layer, aₖ⁻)
324
+ modify_layer! (Val (:keep_positive_zero_bias ), layer) # re-modify mutated pullback
325
+ Rₖ .- = rule . β .* aₖ⁻ .* only ( pullbackᵝ⁻ (yᵝ) )
304
326
reset! ()
305
-
306
- y_α = Rₖ₊₁ ./ modify_denominator (rule, out_1 + out_2)
307
- y_β = Rₖ₊₁ ./ modify_denominator (rule, out_3 + out_4)
308
- Rₖ .=
309
- rule. α .* (aₖ⁺ .* only (pullback_1 (y_α)) + aₖ⁻ .* only (pullback_2 (y_α))) .-
310
- rule. β .* (aₖ⁺ .* only (pullback_3 (y_β)) + aₖ⁻ .* only (pullback_4 (y_β)))
311
327
return nothing
312
328
end
313
329
@@ -331,15 +347,20 @@ function lrp!(Rₖ, rule::ZPlusRule, layer::L, aₖ, Rₖ₊₁) where {L}
331
347
aₖ⁺ = keep_positive (aₖ)
332
348
aₖ⁻ = keep_negative (aₖ)
333
349
350
+ # Linearize around positive & negative activations (aₖ⁺, aₖ⁻)
334
351
modify_layer! (Val (:keep_positive ), layer)
335
- out_1, pullback_1 = Zygote. pullback (layer, aₖ⁺)
352
+ aₖ₊₁⁺, pullback⁺ = Zygote. pullback (layer, aₖ⁺)
336
353
reset! ()
337
354
modify_layer! (Val (:keep_negative_zero_bias ), layer)
338
- out_2, pullback_2 = Zygote. pullback (layer, aₖ⁻)
339
- reset! ()
355
+ aₖ₊₁⁻, pullback⁻ = Zygote. pullback (layer, aₖ⁻)
340
356
341
- y_α = Rₖ₊₁ ./ modify_denominator (rule, out_1 + out_2)
342
- Rₖ .= aₖ⁺ .* only (pullback_1 (y_α)) + aₖ⁻ .* only (pullback_2 (y_α))
357
+ # Evaluate pullbacks
358
+ y = Rₖ₊₁ ./ modify_denominator (rule, aₖ₊₁⁺ + aₖ₊₁⁻)
359
+ Rₖ .= aₖ⁻ .* only (pullback⁻ (y))
360
+ reset! ()
361
+ modify_layer! (Val (:keep_positive ), layer) # re-modify mutated pullback
362
+ Rₖ .+ = aₖ⁺ .* only (pullback⁺ (y))
363
+ reset! ()
343
364
return nothing
344
365
end
345
366
@@ -355,8 +376,9 @@ for R in (ZeroRule, EpsilonRule, GammaRule, WSquareRule, FlatRule)
355
376
@eval function lrp! (Rₖ, rule:: $R , layer:: Dense , aₖ, Rₖ₊₁)
356
377
reset! = get_layer_resetter (rule, layer)
357
378
modify_layer! (rule, layer)
358
- ãₖ₊₁ = modify_denominator (rule, layer (modify_input (rule, aₖ)))
359
- @tullio Rₖ[j, b] = aₖ[j, b] * layer. weight[k, j] * Rₖ₊₁[k, b] / ãₖ₊₁[k, b]
379
+ ãₖ = modify_input (rule, aₖ)
380
+ ãₖ₊₁ = modify_denominator (rule, preactivation (layer, ãₖ))
381
+ @tullio Rₖ[j, b] = ãₖ[j, b] * layer. weight[k, j] * Rₖ₊₁[k, b] / ãₖ₊₁[k, b]
360
382
reset! ()
361
383
return nothing
362
384
end
0 commit comments