@@ -47,21 +47,28 @@ otherwise throw an `ArgumentError`.
47
47
check_compat (rule, layer) = require_weight_and_bias (rule, layer)
48
48
49
49
"""
50
- modify_layer!(rule, layer)
50
+ modify_layer!(rule, layer; ignore_bias=false )
51
51
52
52
In-place modify layer parameters by calling `modify_param!` before computing relevance
53
53
propagation.
54
54
55
55
## Note
56
56
When implementing a custom `modify_layer!` function, `modify_param!` will not be called.
57
57
"""
58
- function modify_layer! (rule:: R , layer:: L ) where {R,L}
59
- if has_weight_and_bias (layer)
60
- modify_param! (rule, layer. weight)
61
- modify_bias! (rule, layer. bias)
62
- end
58
+ function modify_layer! (rule:: R , layer:: L ; ignore_bias= false ) where {R,L}
59
+ ! has_weight_and_bias (layer) && return nothing # skip all
60
+ modify_weight! (rule, layer. weight)
61
+
62
+ # Checks that skip bias modification:
63
+ ignore_bias && return nothing
64
+ isa (layer. bias, Flux. Zeros) && return nothing # skip if bias=Flux.Zeros (Flux <= v0.12)
65
+ isa (layer. bias, Bool) && ! layer. bias && return nothing # skip if bias=false (Flux >= v0.13)
66
+
67
+ modify_bias! (rule, layer. bias)
63
68
return nothing
64
69
end
70
+ modify_weight! (rule:: R , W) where {R} = modify_param! (rule, W)
71
+ modify_bias! (rule:: R , b) where {R} = modify_param! (rule, b)
65
72
66
73
"""
67
74
modify_param!(rule, W)
@@ -71,17 +78,15 @@ Inplace-modify parameters before computing the relevance.
71
78
"""
72
79
modify_param! (rule, param) = nothing # general fallback
73
80
74
- # Useful presets:
75
- modify_param! (:: Val{:mask_positive } , p) = p . = max .( zero ( eltype (p)), p)
76
- modify_param! (:: Val{:mask_negative } , p) = p . = min .( zero ( eltype (p)), p)
81
+ # Useful presets that allow us to work around bias-free layers :
82
+ modify_param! (:: Val{:keep_positive } , p) = keep_positive! ( p)
83
+ modify_param! (:: Val{:keep_negative } , p) = keep_negative! ( p)
77
84
78
- # Internal wrapper functions for bias-free layers.
79
- modify_bias! (rule:: R , b) where {R} = modify_param! (rule, b)
80
- modify_bias! (rule, b:: Flux.Zeros ) = nothing # skip if bias=Flux.Zeros (Flux <= v0.12)
81
- function modify_bias! (rule, b:: Bool ) # skip if bias=false (Flux >= v0.13)
82
- @assert b == false
83
- return nothing
84
- end
85
+ modify_weight! (:: Val{:keep_positive_zero_bias} , W) = keep_positive! (W)
86
+ modify_bias! (:: Val{:keep_positive_zero_bias} , b) = fill! (b, zero (eltype (b)))
87
+
88
+ modify_weight! (:: Val{:keep_negative_zero_bias} , W) = keep_negative! (W)
89
+ modify_bias! (:: Val{:keep_negative_zero_bias} , b) = fill! (b, zero (eltype (b)))
85
90
86
91
# Internal function that resets parameters by capturing them in a closure.
87
92
# Returns a function `reset!` that resets the parameters to their original state when called.
@@ -226,12 +231,12 @@ function lrp!(Rₖ, rule::ZBoxRule, layer::L, aₖ, Rₖ₊₁) where {L}
226
231
aₖ₊₁, pullback = Zygote. pullback (layer, aₖ)
227
232
228
233
# Compute pullback for W⁺, b⁺
229
- modify_layer! (Val{ :mask_positive } , layer)
234
+ modify_layer! (Val ( :keep_positive ) , layer)
230
235
aₖ₊₁⁺, pullback⁺ = Zygote. pullback (layer, l)
231
236
reset! ()
232
237
233
238
# Compute pullback for W⁻, b⁻
234
- modify_layer! (Val{ :mask_negative } , layer)
239
+ modify_layer! (Val ( :keep_negative ) , layer)
235
240
aₖ₊₁⁻, pullback⁻ = Zygote. pullback (layer, h)
236
241
reset! ()
237
242
@@ -246,12 +251,7 @@ function zbox_input(in::AbstractArray{T}, A::AbstractArray) where {T}
246
251
return convert .(T, A)
247
252
end
248
253
249
- # Special cases for rules that don't modify params for extra performance:
250
- for R in (ZeroRule, EpsilonRule)
251
- @eval get_layer_resetter (:: $R , l) = Returns (nothing )
252
- @eval lrp! (Rₖ, :: $R , :: DropoutLayer , aₖ, Rₖ₊₁) = (Rₖ .= Rₖ₊₁)
253
- @eval lrp! (Rₖ, :: $R , :: ReshapingLayer , aₖ, Rₖ₊₁) = (Rₖ .= reshape (Rₖ₊₁, size (aₖ)))
254
- end
254
+
255
255
256
256
# Special cases for rules that don't modify params for extra performance:
257
257
for R in (ZeroRule, EpsilonRule)
0 commit comments