Skip to content

Commit ac52d64

Browse files
authored
Add default composites (#87)
* rename `RuleMap` to `TypeRule` * add `FirstLayerTypeRule` and `LastLayerTypeRule` * Add default composites * Document default composites * Add compat entries for test dependencies
1 parent 526a5b7 commit ac52d64

16 files changed

+427
-150
lines changed

Project.toml

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -27,14 +27,3 @@ PrettyTables = "1"
2727
Tullio = "0.3"
2828
Zygote = "0.6"
2929
julia = "1.6"
30-
31-
[extras]
32-
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
33-
LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
34-
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
35-
ReferenceTests = "324d217c-45ce-50fc-942e-d289b448e8cf"
36-
Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
37-
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
38-
39-
[targets]
40-
test = ["JLD2", "LoopVectorization", "Random", "ReferenceTests", "Suppressor", "Test"]

docs/literate/advanced_lrp.jl

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@ index = 10
1919
x, _ = MNIST(Float32, :test)[10]
2020
input = reshape(x, 28, 28, 1, :);
2121

22-
# ## Custom LRP composites
22+
# ## LRP composites
23+
# ### Custom composites
2324
# When creating an LRP-analyzer, we can assign individual rules to each layer.
2425
# The array of rules has to match the length of the Flux chain:
2526
rules = [
@@ -49,11 +50,11 @@ heatmap(input, analyzer)
4950
# To obtain the same set of rules as in the previous example, we can define
5051
composite = Composite(
5152
ZeroRule(), # default rule
52-
GlobalRuleMap(
53+
GlobalTypeRule(
5354
Conv => GammaRule(), # apply GammaRule on all convolutional layers
5455
MaxPool => EpsilonRule(), # apply EpsilonRule on all pooling-layers
5556
),
56-
FirstRule(ZBoxRule(0.0f0, 1.0f0)), # apply ZBoxRule on the first layer
57+
FirstLayerRule(ZBoxRule(0.0f0, 1.0f0)), # apply ZBoxRule on the first layer
5758
)
5859

5960
analyzer = LRP(model, composite) # construct LRP analyzer from composite
@@ -63,24 +64,31 @@ heatmap(input, analyzer)
6364
analyzer.rules # show rules
6465

6566
# ### Composite primitives
66-
# The following sets of primitives can used to construct a [`Composite`](@ref).
67+
# The following [Composite primitives](@ref composite_primitive_api) can used to construct a [`Composite`](@ref).
6768
#
6869
# To apply a single rule, use:
6970
# * [`LayerRule`](@ref) to apply a rule to the `n`-th layer of a model
70-
# * [`GlobalRule`](@ref) to apply a rule to all layers of a model
71-
# * [`RangeRule`](@ref) to apply a rule to a positional range of layers of a model
72-
# * [`FirstRule`](@ref) to apply a rule to the first layer of a model
73-
# * [`LastRule`](@ref) to apply a rule to the last layer of a model
71+
# * [`GlobalRule`](@ref) to apply a rule to all layers
72+
# * [`RangeRule`](@ref) to apply a rule to a positional range of layers
73+
# * [`FirstLayerRule`](@ref) to apply a rule to the first layer
74+
# * [`LastLayerRule`](@ref) to apply a rule to the last layer
7475
#
75-
# To apply a set of rules to multiple layers, use:
76-
# * [`GlobalRuleMap`](@ref) to apply a dictionary that maps layer types to LRP-rules
77-
# * [`RangeRuleMap`](@ref) for a `RuleMap` on generalized ranges
78-
# * [`FirstNRuleMap`](@ref) for a `RuleMap` on the first `n` layers of a model
79-
# * [`LastNRuleMap`](@ref) for a `RuleMap` on the last `n` layers
76+
# To apply a set of rules to layers based on their type, use:
77+
# * [`GlobalTypeRule`](@ref) to apply a dictionary that maps layer types to LRP-rules
78+
# * [`RangeTypeRule`](@ref) for a `TypeRule` on generalized ranges
79+
# * [`FirstLayerTypeRule`](@ref) for a `TypeRule` on the first layer of a model
80+
# * [`LastLayerTypeRule`](@ref) for a `TypeRule` on the last layer
81+
# * [`FirstNTypeRule`](@ref) for a `TypeRule` on the first `n` layers
82+
# * [`LastNTypeRule`](@ref) for a `TypeRule` on the last `n` layers
8083
#
8184
# Primitives are called sequentially in the order the `Composite` was created with
8285
# and overwrite rules specified by previous primitives.
8386

87+
# ### Default composites
88+
# A list of implemented default composites can be found under
89+
# [Default composites](@ref default_composite_api) in the API reference, e.g. [`EpsilonPlusFlat`](@ref):
90+
EpsilonPlusFlat()
91+
8492
# ## Custom LRP rules
8593
# Let's define a rule that modifies the weights and biases of our layer on the forward pass.
8694
# The rule has to be of type `AbstractLRPRule`.

docs/src/api.md

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -50,21 +50,33 @@ LRP_CONFIG.supports_activation
5050
Composite
5151
```
5252

53+
### [Composite primitives](@id composite_primitive_api)
5354
Composite primitives that apply a single rule:
5455
```@docs
5556
LayerRule
5657
GlobalRule
5758
RangeRule
58-
FirstRule
59-
LastRule
59+
FirstLayerRule
60+
LastLayerRule
6061
```
6162

6263
Composite primitives that apply a set of rules to multiple layers:
6364
```@docs
64-
GlobalRuleMap
65-
RangeRuleMap
66-
FirstNRuleMap
67-
LastNRuleMap
65+
GlobalTypeRule
66+
RangeTypeRule
67+
FirstLayerTypeRule
68+
LastLayerTypeRule
69+
FirstNTypeRule
70+
LastNTypeRule
71+
```
72+
73+
### [Default composites](@id default_composite_api)
74+
```@docs
75+
EpsilonGammaBox
76+
EpsilonPlus
77+
EpsilonAlpha2Beta1
78+
EpsilonPlusFlat
79+
EpsilonAlpha2Beta1Flat
6880
```
6981

7082
# Utilities

src/ExplainableAI.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,13 @@ export check_model
5353

5454
# LRP composites
5555
export Composite, AbstractCompositePrimitive
56-
export LayerRule, GlobalRule, RangeRule, FirstRule, LastRule
57-
export GlobalRuleMap, RangeRuleMap, FirstNRuleMap, LastNRuleMap
56+
export LayerRule, GlobalRule, RangeRule, FirstLayerRule, LastLayerRule
57+
export GlobalTypeRule, RangeTypeRule, FirstLayerTypeRule, LastLayerTypeRule
58+
export FirstNTypeRule, LastNTypeRule
59+
# Default composites
60+
export EpsilonGammaBox, EpsilonPlus, EpsilonAlpha2Beta1, EpsilonPlusFlat
61+
export EpsilonAlpha2Beta1Flat
62+
# Useful type unions
5863
export ConvLayer, PoolingLayer, DropoutLayer, ReshapingLayer
5964

6065
# heatmapping

0 commit comments

Comments
 (0)