Skip to content

Commit 509b8b5

Browse files
committed
Fix self-documenting composite presets
need to be included after `src/lrp/show.jl` which defines `Base.show` on `Composite`.
1 parent 52c6d1a commit 509b8b5

File tree

3 files changed

+155
-153
lines changed

3 files changed

+155
-153
lines changed

src/ExplainableAI.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ include("lrp/rules.jl")
3131
include("lrp/composite.jl")
3232
include("lrp/lrp.jl")
3333
include("lrp/show.jl")
34+
include("lrp/composite_presets.jl") # uses lrp/show.jl
3435
include("heatmap.jl")
3536
include("preprocessing.jl")
3637
export analyze

src/lrp/composite.jl

Lines changed: 52 additions & 153 deletions
Original file line numberDiff line numberDiff line change
@@ -1,70 +1,10 @@
1-
"""
2-
Composite([default_rule=LRPZero()], primitives...)
3-
4-
Automatically contructs a list of LRP-rules by sequentially applying composite primitives.
5-
6-
# Primitives
7-
To apply a single rule, use:
8-
* [`LayerRule`](@ref) to apply a rule to the `n`-th layer of a model
9-
* [`GlobalRule`](@ref) to apply a rule to all layers
10-
* [`RangeRule`](@ref) to apply a rule to a positional range of layers
11-
* [`FirstLayerRule`](@ref) to apply a rule to the first layer
12-
* [`LastLayerRule`](@ref) to apply a rule to the last layer
13-
14-
To apply a set of rules to layers based on their type, use:
15-
* [`GlobalTypeRule`](@ref) to apply a dictionary that maps layer types to LRP-rules
16-
* [`RangeTypeRule`](@ref) for a `TypeRule` on generalized ranges
17-
* [`FirstLayerTypeRule`](@ref) for a `TypeRule` on the first layer of a model
18-
* [`LastLayerTypeRule`](@ref) for a `TypeRule` on the last layer
19-
* [`FirstNTypeRule`](@ref) for a `TypeRule` on the first `n` layers
20-
* [`LastNTypeRule`](@ref) for a `TypeRule` on the last `n` layers
21-
22-
# Example
23-
Using a flattened VGG11 model:
24-
```julia-repl
25-
julia> composite = Composite(
26-
GlobalTypeRule(
27-
ConvLayer => AlphaBetaRule(),
28-
Dense => EpsilonRule(),
29-
PoolingLayer => EpsilonRule(),
30-
DropoutLayer => PassRule(),
31-
ReshapingLayer => PassRule(),
32-
),
33-
FirstNTypeRule(7, Conv => FlatRule()),
34-
);
35-
36-
julia> analyzer = LRP(model, composite);
37-
38-
julia> analyzer.rules
39-
19-element Vector{AbstractLRPRule}:
40-
FlatRule()
41-
EpsilonRule{Float32}(1.0f-6)
42-
FlatRule()
43-
EpsilonRule{Float32}(1.0f-6)
44-
FlatRule()
45-
FlatRule()
46-
EpsilonRule{Float32}(1.0f-6)
47-
AlphaBetaRule{Float32}(2.0f0, 1.0f0)
48-
AlphaBetaRule{Float32}(2.0f0, 1.0f0)
49-
EpsilonRule{Float32}(1.0f-6)
50-
AlphaBetaRule{Float32}(2.0f0, 1.0f0)
51-
AlphaBetaRule{Float32}(2.0f0, 1.0f0)
52-
EpsilonRule{Float32}(1.0f-6)
53-
PassRule()
54-
EpsilonRule{Float32}(1.0f-6)
55-
PassRule()
56-
EpsilonRule{Float32}(1.0f-6)
57-
PassRule()
58-
EpsilonRule{Float32}(1.0f-6)
59-
```
60-
"""
1+
# A Composite is a container of primitives, which are sequentially applied
612
struct Composite{T<:Union{Tuple,AbstractVector}}
623
primitives::T
634
end
645
Composite(rule::AbstractLRPRule, prims...) = Composite((GlobalRule(rule), prims...))
656
Composite(prims...) = Composite(prims)
667

67-
# A Composite is a container of primitives, which are sequentially applied
688
const COMPOSITE_DEFAULT_RULE = ZeroRule()
699
function (c::Composite)(model)
7010
rules = Vector{AbstractLRPRule}(repeat([COMPOSITE_DEFAULT_RULE], length(model.layers)))
@@ -270,104 +210,63 @@ function _range_rule_map!(rules, layers, map, range)
270210
end
271211

272212
"""
273-
EpsilonGammaBox(low, high; [epsilon=1.0f-6, gamma=0.25f0])
274-
275-
Composite using the following primitives:
276-
```julia-repl
277-
julia> EpsilonGammaBox(-3.0f0, 3.0f0)
278-
$(repr("text/plain", EpsilonGammaBox(-3.0f0, 3.0f0)))
279-
```
280-
"""
281-
function EpsilonGammaBox(low, high; epsilon=1.0f-6, gamma=0.25f0)
282-
return Composite(
283-
GlobalTypeRule(
284-
ConvLayer => GammaRule(gamma),
285-
Dense => EpsilonRule(epsilon),
286-
DropoutLayer => PassRule(),
287-
ReshapingLayer => PassRule(),
288-
),
289-
FirstLayerTypeRule(ConvLayer => ZBoxRule(low, high)),
290-
)
291-
end
292-
293-
"""
294-
EpsilonPlus(; [epsilon=1.0f-6])
295-
296-
Composite using the following primitives:
297-
```julia-repl
298-
julia> EpsilonPlus()
299-
$(repr("text/plain", EpsilonPlus()))
300-
```
301-
"""
302-
function EpsilonPlus(; epsilon=1.0f-6)
303-
return Composite(
304-
GlobalTypeRule(
305-
ConvLayer => AlphaBetaRule(1.0f0, 0.0f0), # TODO: replace with ZPlusRule
306-
Dense => EpsilonRule(epsilon),
307-
DropoutLayer => PassRule(),
308-
ReshapingLayer => PassRule(),
309-
),
310-
)
311-
end
213+
Composite([default_rule=LRPZero()], primitives...)
312214
313-
"""
314-
EpsilonAlpha2Beta1(; [epsilon=1.0f-6])
215+
Automatically contructs a list of LRP-rules by sequentially applying composite primitives.
315216
316-
Composite using the following primitives:
317-
```julia-repl
318-
julia> EpsilonAlpha2Beta1()
319-
$(repr("text/plain", EpsilonAlpha2Beta1()))
320-
```
321-
"""
322-
function EpsilonAlpha2Beta1(; epsilon=1.0f-6)
323-
return Composite(
324-
GlobalTypeRule(
325-
ConvLayer => AlphaBetaRule(2.0f0, 1.0f0),
326-
Dense => EpsilonRule(epsilon),
327-
DropoutLayer => PassRule(),
328-
ReshapingLayer => PassRule(),
329-
),
330-
)
331-
end
217+
# Primitives
218+
To apply a single rule, use:
219+
* [`LayerRule`](@ref) to apply a rule to the `n`-th layer of a model
220+
* [`GlobalRule`](@ref) to apply a rule to all layers
221+
* [`RangeRule`](@ref) to apply a rule to a positional range of layers
222+
* [`FirstLayerRule`](@ref) to apply a rule to the first layer
223+
* [`LastLayerRule`](@ref) to apply a rule to the last layer
332224
333-
"""
334-
EpsilonPlusFlat(; [epsilon=1.0f-6])
225+
To apply a set of rules to layers based on their type, use:
226+
* [`GlobalTypeRule`](@ref) to apply a dictionary that maps layer types to LRP-rules
227+
* [`RangeTypeRule`](@ref) for a `TypeRule` on generalized ranges
228+
* [`FirstLayerTypeRule`](@ref) for a `TypeRule` on the first layer of a model
229+
* [`LastLayerTypeRule`](@ref) for a `TypeRule` on the last layer
230+
* [`FirstNTypeRule`](@ref) for a `TypeRule` on the first `n` layers
231+
* [`LastNTypeRule`](@ref) for a `TypeRule` on the last `n` layers
335232
336-
Composite using the following primitives:
233+
# Example
234+
Using a flattened VGG11 model:
337235
```julia-repl
338-
julia> EpsilonPlusFlat()
339-
$(repr("text/plain", EpsilonPlusFlat()))
340-
```
341-
"""
342-
function EpsilonPlusFlat(; epsilon=1.0f-6)
343-
return Composite(
344-
GlobalTypeRule(
345-
ConvLayer => AlphaBetaRule(1.0f0, 0.0f0), # TODO: replace with ZPlusRule
346-
Dense => EpsilonRule(epsilon),
347-
DropoutLayer => PassRule(),
348-
ReshapingLayer => PassRule(),
349-
),
350-
FirstLayerTypeRule(ConvLayer => FlatRule(), Dense => FlatRule()),
351-
)
352-
end
236+
julia> composite = Composite(
237+
GlobalTypeRule(
238+
ConvLayer => AlphaBetaRule(),
239+
Dense => EpsilonRule(),
240+
PoolingLayer => EpsilonRule(),
241+
DropoutLayer => PassRule(),
242+
ReshapingLayer => PassRule(),
243+
),
244+
FirstNTypeRule(7, Conv => FlatRule()),
245+
);
353246
354-
"""
355-
EpsilonAlpha2Beta1Flat(; [epsilon=1.0f-6])
247+
julia> analyzer = LRP(model, composite);
356248
357-
Composite using the following primitives:
358-
```julia-repl
359-
julia> EpsilonAlpha2Beta1Flat()
360-
$(repr("text/plain", EpsilonAlpha2Beta1Flat()))
249+
julia> analyzer.rules
250+
19-element Vector{AbstractLRPRule}:
251+
FlatRule()
252+
EpsilonRule{Float32}(1.0f-6)
253+
FlatRule()
254+
EpsilonRule{Float32}(1.0f-6)
255+
FlatRule()
256+
FlatRule()
257+
EpsilonRule{Float32}(1.0f-6)
258+
AlphaBetaRule{Float32}(2.0f0, 1.0f0)
259+
AlphaBetaRule{Float32}(2.0f0, 1.0f0)
260+
EpsilonRule{Float32}(1.0f-6)
261+
AlphaBetaRule{Float32}(2.0f0, 1.0f0)
262+
AlphaBetaRule{Float32}(2.0f0, 1.0f0)
263+
EpsilonRule{Float32}(1.0f-6)
264+
PassRule()
265+
EpsilonRule{Float32}(1.0f-6)
266+
PassRule()
267+
EpsilonRule{Float32}(1.0f-6)
268+
PassRule()
269+
EpsilonRule{Float32}(1.0f-6)
361270
```
362271
"""
363-
function EpsilonAlpha2Beta1Flat(; epsilon=1.0f-6)
364-
return Composite(
365-
GlobalTypeRule(
366-
ConvLayer => AlphaBetaRule(2.0f0, 1.0f0),
367-
Dense => EpsilonRule(epsilon),
368-
DropoutLayer => PassRule(),
369-
ReshapingLayer => PassRule(),
370-
),
371-
FirstLayerTypeRule(ConvLayer => FlatRule(), Dense => FlatRule()),
372-
)
373-
end
272+
Composite

src/lrp/composite_presets.jl

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
"""
2+
EpsilonGammaBox(low, high; [epsilon=1.0f-6, gamma=0.25f0])
3+
4+
Composite using the following primitives:
5+
```julia-repl
6+
julia> EpsilonGammaBox(-3.0f0, 3.0f0)
7+
$(repr("text/plain", EpsilonGammaBox(-3.0f0, 3.0f0)))
8+
```
9+
"""
10+
function EpsilonGammaBox(low, high; epsilon=1.0f-6, gamma=0.25f0)
11+
return Composite(
12+
GlobalTypeRule(
13+
ConvLayer => GammaRule(gamma),
14+
Dense => EpsilonRule(epsilon),
15+
DropoutLayer => PassRule(),
16+
ReshapingLayer => PassRule(),
17+
),
18+
FirstLayerTypeRule(ConvLayer => ZBoxRule(low, high)),
19+
)
20+
end
21+
22+
"""
23+
EpsilonPlus(; [epsilon=1.0f-6])
24+
25+
Composite using the following primitives:
26+
```julia-repl
27+
julia> EpsilonPlus()
28+
$(repr("text/plain", EpsilonPlus()))
29+
```
30+
"""
31+
function EpsilonPlus(; epsilon=1.0f-6)
32+
return Composite(
33+
GlobalTypeRule(
34+
ConvLayer => AlphaBetaRule(1.0f0, 0.0f0), # TODO: replace with ZPlusRule
35+
Dense => EpsilonRule(epsilon),
36+
DropoutLayer => PassRule(),
37+
ReshapingLayer => PassRule(),
38+
),
39+
)
40+
end
41+
42+
"""
43+
EpsilonAlpha2Beta1(; [epsilon=1.0f-6])
44+
45+
Composite using the following primitives:
46+
```julia-repl
47+
julia> EpsilonAlpha2Beta1()
48+
$(repr("text/plain", EpsilonAlpha2Beta1()))
49+
```
50+
"""
51+
function EpsilonAlpha2Beta1(; epsilon=1.0f-6)
52+
return Composite(
53+
GlobalTypeRule(
54+
ConvLayer => AlphaBetaRule(2.0f0, 1.0f0),
55+
Dense => EpsilonRule(epsilon),
56+
DropoutLayer => PassRule(),
57+
ReshapingLayer => PassRule(),
58+
),
59+
)
60+
end
61+
62+
"""
63+
EpsilonPlusFlat(; [epsilon=1.0f-6])
64+
65+
Composite using the following primitives:
66+
```julia-repl
67+
julia> EpsilonPlusFlat()
68+
$(repr("text/plain", EpsilonPlusFlat()))
69+
```
70+
"""
71+
function EpsilonPlusFlat(; epsilon=1.0f-6)
72+
return Composite(
73+
GlobalTypeRule(
74+
ConvLayer => AlphaBetaRule(1.0f0, 0.0f0), # TODO: replace with ZPlusRule
75+
Dense => EpsilonRule(epsilon),
76+
DropoutLayer => PassRule(),
77+
ReshapingLayer => PassRule(),
78+
),
79+
FirstLayerTypeRule(ConvLayer => FlatRule(), Dense => FlatRule()),
80+
)
81+
end
82+
83+
"""
84+
EpsilonAlpha2Beta1Flat(; [epsilon=1.0f-6])
85+
86+
Composite using the following primitives:
87+
```julia-repl
88+
julia> EpsilonAlpha2Beta1Flat()
89+
$(repr("text/plain", EpsilonAlpha2Beta1Flat()))
90+
```
91+
"""
92+
function EpsilonAlpha2Beta1Flat(; epsilon=1.0f-6)
93+
return Composite(
94+
GlobalTypeRule(
95+
ConvLayer => AlphaBetaRule(2.0f0, 1.0f0),
96+
Dense => EpsilonRule(epsilon),
97+
DropoutLayer => PassRule(),
98+
ReshapingLayer => PassRule(),
99+
),
100+
FirstLayerTypeRule(ConvLayer => FlatRule(), Dense => FlatRule()),
101+
)
102+
end

0 commit comments

Comments
 (0)