
Commit e7346fd

Add model canonization (#51)
* Add network canonization
* Move types to new file
* Refactor `strip_softmax`
1 parent f47d3dc commit e7346fd
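
For orientation, here is a minimal usage sketch of the new `canonize` export (not part of the commit; the model and layer sizes are purely illustrative):

using Flux
using ExplainableAI

# A toy chain where BatchNorm follows a linear (identity-activation) Dense layer.
model = Chain(Dense(10, 5), BatchNorm(5, relu), Dense(5, 2), softmax)

# canonize flattens the chain and fuses the BatchNorm into the preceding Dense layer,
# leaving a shorter model that computes the same function at inference time.
model_canonized = canonize(model)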

File tree: 8 files changed, +212 −27 lines

docs/src/api.md

Lines changed: 2 additions & 1 deletion
@@ -14,7 +14,7 @@ SmoothGrad
 ```

 `SmoothGrad` is a special case of `InputAugmentation`, which can be applied as a wrapper to any analyzer:
-```@doc
+```@docs
 InputAugmentation
 ```

@@ -41,6 +41,7 @@ LRP_CONFIG.supports_activation
 ```@docs
 strip_softmax
 flatten_model
+canonize
 ```

 # Index

src/ExplainableAI.jl

Lines changed: 3 additions & 1 deletion
@@ -18,8 +18,10 @@ using PrettyTables

 include("neuron_selection.jl")
 include("analyze_api.jl")
+include("types.jl")
 include("flux.jl")
 include("utils.jl")
+include("canonize.jl")
 include("input_augmentation.jl")
 include("gradient.jl")
 include("lrp_checks.jl")
@@ -46,6 +48,6 @@ export check_model
 export heatmap

 # utils
-export strip_softmax, flatten_model, flatten_chain
+export strip_softmax, flatten_model, flatten_chain, canonize

 end # module

src/canonize.jl

Lines changed: 51 additions & 0 deletions
@@ -0,0 +1,51 @@
+function fuse_batchnorm(d::Dense, bn::BatchNorm)
+    d.σ != identity &&
+        throw(ArgumentError("Can't fuse Dense layer with activation $(d.σ)."))
+    scale = safedivide(bn.γ, sqrt.(bn.σ²))
+    W = scale .* d.weight
+    b = scale .* (d.bias - bn.μ) + bn.β
+    return Dense(W, b, bn.λ)
+end
+
+function fuse_batchnorm(c::Conv, bn::BatchNorm)
+    c.σ != identity && throw(ArgumentError("Can't fuse Conv layer with activation $(c.σ)."))
+    scale = safedivide(bn.γ, sqrt.(bn.σ²))
+    W = c.weight .* reshape(scale, 1, 1, 1, :)
+    b = scale .* (c.bias - bn.μ) + bn.β
+    return Conv(W, b, bn.λ)
+end
+
+"""
+    try_fusing(model, i)
+
+Attempt to fuse pair of model layers at indices `i` and `i+1`.
+Returns fused model and `true` if layers were fused, unmodified model and `false` otherwise.
+"""
+function try_fusing(model, i)
+    l1 = model[i]
+    l2 = model[i + 1]
+    if l1 isa Union{Dense,Conv} && l2 isa BatchNorm && activation(l1) == identity
+        if i == length(model) - 1
+            model = Chain(model[1:(i - 1)]..., fuse_batchnorm(l1, l2))
+        end
+        model = Chain(model[1:(i - 1)]..., fuse_batchnorm(l1, l2), model[(i + 2):end]...)
+        return model, true
+    end
+    return model, false
+end
+
+"""
+    canonize(model)
+
+Canonize model by flattening it and fusing BatchNorm layers into preceding Dense and Conv
+layers with linear activation functions.
+"""
+function canonize(model::Chain)
+    model = flatten_model(model)
+    i = 1
+    while i < length(model)
+        model, fused = try_fusing(model, i)
+        !fused && (i += 1)
+    end
+    return model
+end
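
For readers unfamiliar with BatchNorm folding: at inference time a BatchNorm layer is an affine map per channel, so it can be absorbed into a preceding linear layer. Below is a standalone sketch of the identity that `fuse_batchnorm` exploits (not part of the commit; it spells out the ε term that the package's `safedivide` helper otherwise takes care of, and the layer sizes are hypothetical):

using Flux

# Hypothetical sizes; the Dense layer has identity activation, as fuse_batchnorm requires.
d = Dense(randn(Float32, 3, 4), randn(Float32, 3))
bn = BatchNorm(3)

# Collect running statistics so that μ and σ² are non-trivial.
Flux.trainmode!(bn)
bn(randn(Float32, 3, 32))
Flux.testmode!(bn)

# At test time BatchNorm computes γ .* (y .- μ) ./ sqrt.(σ² .+ ϵ) .+ β per channel,
# which folds into the preceding layer's weight and bias:
scale = bn.γ ./ sqrt.(bn.σ² .+ bn.ϵ)
fused = Dense(scale .* d.weight, scale .* (d.bias .- bn.μ) .+ bn.β, bn.λ)

x = randn(Float32, 4, 8)
@assert fused(x) ≈ bn(d(x))  # the fused layer reproduces Dense followed by BatchNorm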

src/flux.jl

Lines changed: 22 additions & 22 deletions
@@ -1,21 +1,14 @@
-## Group layers by type:
-const ConvLayer = Union{Conv} # TODO: DepthwiseConv, ConvTranspose, CrossCor
-const DropoutLayer = Union{Dropout,typeof(Flux.dropout),AlphaDropout}
-const ReshapingLayer = Union{typeof(Flux.flatten)}
-# Pooling layers
-const MaxPoolLayer = Union{MaxPool,AdaptiveMaxPool,GlobalMaxPool}
-const MeanPoolLayer = Union{MeanPool,AdaptiveMeanPool,GlobalMeanPool}
-const PoolingLayer = Union{MaxPoolLayer,MeanPoolLayer}
-# Activation functions that are similar to ReLU
-const ReluLikeActivation = Union{
-    typeof(relu),typeof(gelu),typeof(swish),typeof(softplus),typeof(mish)
-}
-# Layers & activation functions supported by LRP
-const LRPSupportedLayer = Union{Dense,ConvLayer,DropoutLayer,ReshapingLayer,PoolingLayer}
-const LRPSupportedActivation = Union{typeof(identity),ReluLikeActivation}
+"""
+    activation(layer)
+
+Return activation function of the layer.
+In case the layer is unknown or no activation function is found, `nothing` is returned.
+"""
+activation(l::Dense) = l.σ
+activation(l::Conv) = l.σ
+activation(l::BatchNorm) = l.λ
+activation(layer) = nothing # default for all other layer types

-_flatten_model(x) = x
-_flatten_model(c::Chain) = [c.layers...]
 """
     flatten_model(c)

@@ -30,8 +23,11 @@ function flatten_model(chain::Chain)
 end
 @deprecate flatten_chain(c) flatten_model(c)

-is_softmax(layer) = layer isa Union{typeof(softmax),typeof(softmax!)}
-has_output_softmax(x) = is_softmax(x)
+_flatten_model(x) = x
+_flatten_model(c::Chain) = [c.layers...]
+
+is_softmax(x) = x isa SoftmaxActivation
+has_output_softmax(x) = is_softmax(x) || is_softmax(activation(x))
 has_output_softmax(model::Chain) = has_output_softmax(model[end])

 """
@@ -56,10 +52,14 @@ Remove softmax activation on model output if it exists.
 function strip_softmax(model::Chain)
     if has_output_softmax(model)
         model = flatten_model(model)
-        return Chain(model.layers[1:(end - 1)]...)
+        if is_softmax(model[end])
+            return Chain(model.layers[1:(end - 1)]...)
+        end
+        return Chain(model.layers[1:(end - 1)]..., strip_softmax(model[end]))
     end
     return model
 end
+strip_softmax(l::Union{Dense,Conv}) = set_params(l, l.weight, l.bias, identity)

 # helper function to work around Flux.Zeros
 function get_params(layer)
@@ -76,5 +76,5 @@ end

 Duplicate layer using weights W, b.
 """
-set_params(l::Conv, W, b) = Conv(l.σ, W, b, l.stride, l.pad, l.dilation, l.groups)
-set_params(l::Dense, W, b) = Dense(W, b, l.σ)
+set_params(l::Conv, W, b, σ=l.σ) = Conv(σ, W, b, l.stride, l.pad, l.dilation, l.groups)
+set_params(l::Dense, W, b, σ=l.σ) = Dense(W, b, σ)
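
A short sketch of what this refactor enables (illustrative layer sizes, not taken from the commit): `strip_softmax` now also removes a softmax that is set as the final layer's activation, by rebuilding that layer with `identity` via `set_params`:

using Flux
using ExplainableAI
using ExplainableAI: activation   # non-exported helper added in this commit

m1 = Chain(Dense(4, 3, relu), Dense(3, 2), softmax)    # softmax as a separate output layer
m2 = Chain(Dense(4, 3, relu), Dense(3, 2, softmax))    # softmax as the last layer's activation

strip_softmax(m1)     # drops the trailing softmax layer
strip_softmax(m2)     # rebuilds the last Dense with identity activation
activation(m2[end])   # softmax; returns nothing for unknown layer types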

src/types.jl

Lines changed: 34 additions & 0 deletions
@@ -0,0 +1,34 @@
+## Layer types
+"""Union type for convolutional layers."""
+const ConvLayer = Union{Conv} # TODO: DepthwiseConv, ConvTranspose, CrossCor
+
+"""Union type for dropout layers."""
+const DropoutLayer = Union{Dropout,typeof(Flux.dropout),AlphaDropout}
+
+"""Union type for reshaping layers such as `flatten`."""
+const ReshapingLayer = Union{typeof(Flux.flatten)}
+
+"""Union type for max pooling layers."""
+const MaxPoolLayer = Union{MaxPool,AdaptiveMaxPool,GlobalMaxPool}
+
+"""Union type for mean pooling layers."""
+const MeanPoolLayer = Union{MeanPool,AdaptiveMeanPool,GlobalMeanPool}
+
+"""Union type for pooling layers."""
+const PoolingLayer = Union{MaxPoolLayer,MeanPoolLayer}
+
+# Activation functions
+"""Union type for ReLU-like activation functions."""
+const ReluLikeActivation = Union{
+    typeof(relu),typeof(gelu),typeof(swish),typeof(softplus),typeof(mish)
+}
+
+"""Union type for softmax activation functions."""
+const SoftmaxActivation = Union{typeof(softmax),typeof(softmax!)}
+
+# Layers & activation functions supported by LRP
+"""Union type for layers that are allowed by default in "deep rectifier networks"."""
+const LRPSupportedLayer = Union{Dense,ConvLayer,DropoutLayer,ReshapingLayer,PoolingLayer}
+
+"""Union type for activation functions that are allowed by default in "deep rectifier networks"."""
+const LRPSupportedActivation = Union{typeof(identity),ReluLikeActivation}
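
As an illustration (not part of the commit), these aliases double as predicates for plain `isa` checks and as dispatch targets:

using Flux
using ExplainableAI: LRPSupportedLayer, LRPSupportedActivation, SoftmaxActivation

Dense(4, 2) isa LRPSupportedLayer        # true
MaxPool((2, 2)) isa LRPSupportedLayer    # true
relu isa LRPSupportedActivation          # true, ReLU-like activations are included
softmax isa SoftmaxActivation            # true, used by the new is_softmax check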

test/runtests.jl

Lines changed: 4 additions & 0 deletions
@@ -19,6 +19,10 @@ using ReferenceTests
     println("Running tests on heatmaps...")
     include("test_heatmaps.jl")
 end
+@testset "Canonize" begin
+    println("Running tests on model canonization...")
+    include("test_canonize.jl")
+end
 @testset "LRP model checks" begin
     println("Running tests on LRP model checks...")
     include("test_checks.jl")

test/test_canonize.jl

Lines changed: 67 additions & 0 deletions
@@ -0,0 +1,67 @@
+using Flux
+using ExplainableAI
+using ExplainableAI: fuse_batchnorm
+using Random
+
+pseudorand(dims...) = rand(MersenneTwister(123), Float32, dims...)
+batchsize = 50
+
+# # Test `fuse_batchnorm` on Dense layer
+ins = 20
+outs = 10
+dense = Dense(ins, outs; init=pseudorand)
+bn_dense = BatchNorm(outs, relu; initβ=pseudorand, initγ=pseudorand)
+model = Chain(dense, bn_dense)
+
+# collect statistics
+x = pseudorand(ins, batchsize)
+Flux.trainmode!(model)
+model(x)
+Flux.testmode!(model)
+
+dense_fused = @inferred fuse_batchnorm(dense, bn_dense)
+@test dense_fused(x) ≈ model(x)
+
+# # Test `fuse_batchnorm` on Conv layer
+insize = (10, 10, 3)
+conv = Conv((3, 3), 3 => 4; init=pseudorand)
+bn_conv = BatchNorm(4, relu; initβ=pseudorand, initγ=pseudorand)
+model = Chain(conv, bn_conv)
+
+# collect statistics
+x = pseudorand(insize..., batchsize)
+Flux.trainmode!(model)
+model(x)
+Flux.testmode!(model)
+
+conv_fused = @inferred fuse_batchnorm(conv, bn_conv)
+@test conv_fused(x) ≈ model(x)
+
+# # Test `canonize` on models
+# Sequential BatchNorm layers should be fused until they create a Dense or Conv layer
+# with non-linear activation function.
+model = Chain(
+    Conv((3, 3), 3 => 6),
+    BatchNorm(6),
+    Conv((3, 3), 6 => 2, identity),
+    BatchNorm(2),
+    BatchNorm(2, softplus),
+    BatchNorm(2),
+    flatten,
+    Dense(72, 10),
+    BatchNorm(10),
+    BatchNorm(10),
+    BatchNorm(10, relu),
+    BatchNorm(10),
+    Dense(10, 10, gelu),
+    BatchNorm(10),
+    softmax,
+)
+Flux.trainmode!(model)
+model(x)
+Flux.testmode!(model)
+model_canonized = canonize(model)
+
+# 6 of the BatchNorm layers should be removed and the outputs should match
+@test length(model_canonized) == length(model) - 6
+@test model(x) ≈ model_canonized(x)

test/test_utils.jl

Lines changed: 29 additions & 3 deletions
@@ -1,6 +1,15 @@
 using Flux
-using ExplainableAI: flatten_model, has_output_softmax, check_output_softmax
+using ExplainableAI: flatten_model, has_output_softmax, check_output_softmax, activation
 using ExplainableAI: stabilize_denom, batch_dim_view, drop_batch_index
+using Random
+
+pseudorand(dims...) = rand(MersenneTwister(123), Float32, dims...)
+
+# Test `activation`
+@test activation(Dense(5, 2, gelu)) == gelu
+@test activation(Conv((5, 5), 3 => 2, softplus)) == softplus
+@test activation(BatchNorm(5, selu)) == selu
+@test isnothing(activation(flatten))

 # flatten_model
 @test flatten_model(Chain(Chain(Chain(abs)), sqrt, Chain(relu))) == Chain(abs, sqrt, relu)
@@ -12,14 +21,31 @@ using ExplainableAI: stabilize_denom, batch_dim_view, drop_batch_index
 @test has_output_softmax(Chain(abs, sqrt, relu, tanh)) == false
 @test has_output_softmax(Chain(Chain(abs), sqrt, Chain(Chain(softmax)))) == true
 @test has_output_softmax(Chain(Chain(abs), Chain(Chain(softmax)), sqrt)) == false
+@test has_output_softmax(Chain(Dense(5, 5, softmax), Dense(5, 5, softmax))) == true
+@test has_output_softmax(Chain(Dense(5, 5, softmax), Dense(5, 5, relu))) == false
+@test has_output_softmax(Chain(Dense(5, 5, softmax), Chain(Dense(5, 5, softmax)))) == true
+@test has_output_softmax(Chain(Dense(5, 5, softmax), Chain(Dense(5, 5, relu)))) == false

 # check_output_softmax
 @test_throws ArgumentError check_output_softmax(Chain(abs, sqrt, relu, softmax))

 # strip_softmax
-@test strip_softmax(Chain(Chain(abs), sqrt, Chain(Chain(softmax)))) == Chain(abs, sqrt) # flatten to remove softmax
+d_softmax = Dense(2, 2, softmax; init=pseudorand)
+d_softmax2 = Dense(2, 2, softmax; init=pseudorand)
+d_relu = Dense(2, 2, relu; init=pseudorand)
+d_identity = Dense(2, 2; init=pseudorand)
+# flatten to remove softmax
+m = strip_softmax(Chain(Chain(abs), sqrt, Chain(Chain(softmax))))
+@test m == Chain(abs, sqrt)
+m1 = strip_softmax(Chain(d_relu, Chain(d_softmax)))
+m2 = Chain(d_relu, d_identity)
+x = rand(Float32, 2, 10)
+@test typeof(m1) == typeof(m2)
+@test m1(x) == m2(x)
+# don't do anything if there is no softmax at the end
 @test strip_softmax(Chain(Chain(abs), Chain(Chain(softmax)), sqrt)) ==
-    Chain(Chain(abs), Chain(Chain(softmax)), sqrt) # don't do anything if there is no softmax at the end
+    Chain(Chain(abs), Chain(Chain(softmax)), sqrt)
+@test strip_softmax(Chain(d_softmax, Chain(d_relu))) == Chain(d_softmax, Chain(d_relu))

 # stabilize_denom
 A = [1.0 0.0 1.0e-25; -1.0 -0.0 -1.0e-25]
