Draft

Changes from 16 of 26 commits

Commits
c213416
add option to normalize / set output relevances
Maximilian-Stefan-Ernst Mar 7, 2024
1ed33fd
start adding package extension for vision transformers
Maximilian-Stefan-Ernst Mar 7, 2024
a80ba77
add NNlib to weakdeps
Maximilian-Stefan-Ernst Mar 7, 2024
4822aa7
fix merge conflict
Maximilian-Stefan-Ernst Mar 7, 2024
b07335a
fix dependencies for extension
Maximilian-Stefan-Ernst Mar 15, 2024
4f79949
add extension
Maximilian-Stefan-Ernst Mar 15, 2024
9f6d5e2
add options to specify last layer relevance
Maximilian-Stefan-Ernst Mar 15, 2024
bfc9e70
fix imports
Maximilian-Stefan-Ernst Mar 15, 2024
5604403
add canonization for SkipConnection layers; fix model splitting edge …
Maximilian-Stefan-Ernst Mar 15, 2024
27b22d2
Merge pull request #1 from Maximilian-Stefan-Ernst/hotfix/canonize
Maximilian-Stefan-Ernst Mar 15, 2024
655b51d
fix formatting
Maximilian-Stefan-Ernst Mar 15, 2024
f8762e1
fix flatten_model edge case
Maximilian-Stefan-Ernst Mar 15, 2024
4e14408
add tests for flatten_model and canonize
Maximilian-Stefan-Ernst Mar 15, 2024
8e606b5
Merge pull request #2 from Maximilian-Stefan-Ernst/hotfix/canonize
Maximilian-Stefan-Ernst Mar 15, 2024
72dcc76
fix in prepare_vit
Maximilian-Stefan-Ernst Mar 15, 2024
e7c7bac
formatter
Maximilian-Stefan-Ernst Mar 15, 2024
06b4280
Merge branch 'attention' into main
Maximilian-Stefan-Ernst Mar 19, 2024
13ba2b7
Merge pull request #3 from Maximilian-Stefan-Ernst/main
Maximilian-Stefan-Ernst Mar 19, 2024
e4721fd
Revert "add options to specify last layer relevance"
Maximilian-Stefan-Ernst Mar 23, 2024
d464696
Revert "add option to normalize / set output relevances"
Maximilian-Stefan-Ernst Mar 23, 2024
f680bd6
remove NNlib dependency
Maximilian-Stefan-Ernst Mar 23, 2024
31ac29b
rename extension
Maximilian-Stefan-Ernst Mar 23, 2024
6cac3e7
rename files to new extension name & make prepare_vit part of canonize
Maximilian-Stefan-Ernst Mar 23, 2024
e2c7e30
move extension rules to src/rules.jl
Maximilian-Stefan-Ernst Mar 23, 2024
17da330
tell users they have to load Metalhead to use the rules
Maximilian-Stefan-Ernst Mar 23, 2024
79b2d40
remove prepare_vit from exports; format
Maximilian-Stefan-Ernst Mar 23, 2024
9 changes: 9 additions & 0 deletions Project.toml
@@ -13,6 +13,13 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
XAIBase = "9b48221d-a747-4c1b-9860-46a1d8ba24a7"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[weakdeps]
Metalhead = "dbeba491-748d-5e0e-a39e-b530a07fa0cc"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"

[extensions]
VisionTransformerExt = ["Metalhead", "NNlib"]

[compat]
Flux = "0.13, 0.14"
MacroTools = "0.5"
@@ -23,3 +30,5 @@ Statistics = "1"
XAIBase = "3"
Zygote = "0.6"
julia = "1.6"
Metalhead = "0.9"
NNlib = "0.9"
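The new [weakdeps] and [extensions] entries mean that VisionTransformerExt is only loaded once both Metalhead and NNlib are present in the user's environment (Julia 1.9+); note that later commits in this PR drop the NNlib weak dependency and rename the extension. A minimal sketch of how a user would activate it:

```julia
using RelevancePropagation   # core package only, no ViT-specific rules yet
using Metalhead, NNlib       # loading both weak dependencies triggers VisionTransformerExt
```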
13 changes: 13 additions & 0 deletions ext/VisionTransformerExt/VisionTransformerExt.jl
@@ -0,0 +1,13 @@
module VisionTransformerExt
using RelevancePropagation, Flux
using RelevancePropagation:
SelfAttentionRule, SelectClassToken, PositionalEmbeddingRule, modify_layer # all used types have to be used explicitly
import RelevancePropagation: prepare_vit, is_compatible, lrp! # all functions to which you want to add methods have to be imported
using Metalhead:
ViPosEmbedding, ClassTokens, MultiHeadSelfAttention, chunk, ViT, seconddimmean
using Metalhead.Layers: _flatten_spatial
using NNlib: split_heads, join_heads, dot_product_attention_scores, ⊠ # ⊠ (batched_mul) and the attention scores are used in rules.jl

include("rules.jl")
include("utils.jl")
end
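Extension modules cannot be imported directly with `using`; if needed, one can check that the extension was activated via `Base.get_extension` (Julia 1.9+). A hedged sketch, assuming the extension name used in this diff (a later commit renames it):

```julia
using RelevancePropagation, Metalhead, NNlib

ext = Base.get_extension(RelevancePropagation, :VisionTransformerExt)
isnothing(ext) && @warn "VisionTransformerExt did not load; check that Metalhead and NNlib are installed"
```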
114 changes: 114 additions & 0 deletions ext/VisionTransformerExt/rules.jl
@@ -0,0 +1,114 @@
# attention
SelfAttentionRule() = SelfAttentionRule(ZeroRule(), ZeroRule())
LRP_CONFIG.supports_layer(::MultiHeadSelfAttention) = true
is_compatible(::SelfAttentionRule, ::MultiHeadSelfAttention) = true

function lrp!(
Rᵏ, rule::SelfAttentionRule, mha::MultiHeadSelfAttention, _modified_layer, aᵏ, Rᵏ⁺¹
)
# query, key, value projections
qkv = mha.qkv_layer(aᵏ)
q, k, v = chunk(qkv, 3; dims=1)
Rᵥ = similar(v)
# attention
nheads = mha.nheads
fdrop = mha.attn_drop
bias = nothing
# reshape to merge batch dimensions
batch_size = size(q)[3:end]
batch_size == size(k)[3:end] == size(v)[3:end] ||
throw(ArgumentError("Batch dimensions have to be the same."))
q, k, v = map(x -> reshape(x, size(x, 1), size(x, 2), :), (q, k, v))
# add head dimension
q, k, v = split_heads.((q, k, v), nheads)
# compute attention scores
αt = dot_product_attention_scores(q, k, bias; fdrop)
# move head dimension to third place
vt = permutedims(v, (1, 3, 2, 4))
xt = vt ⊠ αt
# remove head dimension
x = permutedims(xt, (1, 3, 2, 4))
x = join_heads(x)
# restore batch dimensions
x = reshape(x, size(x, 1), size(x, 2), batch_size...)
Rₐ = similar(x)

# lrp pass
## forward: aᵏ ->(v_proj) v ->(attention) x ->(out_proj) out
## lrp: Rᵏ <-(value_rule) Rᵥ <-(AH-Rule) Rₐ <-(out_rule) Rᵏ⁺¹
## output projection
lrp!(
Rₐ,
rule.out_rule,
mha.projection[1],
modify_layer(rule.out_rule, mha.projection[1]),
x,
Rᵏ⁺¹,
)
## attention
lrp_attention!(Rᵥ, xt, αt, vt, nheads, batch_size, Rₐ)
## value projection
_, _, w = chunk(mha.qkv_layer.weight, 3; dims=1)
_, _, b = chunk(mha.qkv_layer.bias, 3; dims=1)
proj = Dense(w, b)
lrp!(Rᵏ, rule.value_rule, proj, modify_layer(rule.value_rule, proj), aᵏ, Rᵥ)
end

function lrp_attention!(Rᵥ, x, α, v, nheads, batch_size, Rₐ)
# input dimensions:
## Rₐ: [embedding x token x batch...]
## x : [embedding x token x head x batch]
## α : [token x token x head x batch]
## v : [embedding x token x head x batch]
## Rᵥ: [embedding x token x batch...]

# reshape Rₐ: combine batch dimensions, split heads, move head dimension
Rₐ = permutedims(
split_heads(reshape(Rₐ, size(Rₐ, 1), size(Rₐ, 2), :), nheads), (1, 3, 2, 4)
)
# compute relevance term
s = Rₐ ./ x
# add extra dimensions for broadcasting
s = reshape(s, size(s, 1), size(s, 2), 1, size(s)[3:end]...)
α = reshape(permutedims(α, (2, 1, 3, 4)), 1, size(α)...)
# compute relevances, broadcasting over extra dimensions
R = α .* s
R = dropdims(sum(R; dims=2); dims=2)
R = R .* v
# reshape relevances (drop extra dimension, move head dimension, join heads, split batch dimensions)
R = join_heads(permutedims(R, (1, 3, 2, 4)))
Rᵥ .= reshape(R, size(R, 1), size(R, 2), batch_size...)
end

#=========================#
# Special ViT layers #
#=========================#

# reshaping image -> token
LRP_CONFIG.supports_layer(::typeof(_flatten_spatial)) = true
function lrp!(Rᵏ, ::ZeroRule, ::typeof(_flatten_spatial), _modified_layer, aᵏ, Rᵏ⁺¹)
Rᵏ .= reshape(permutedims(Rᵏ⁺¹, (2, 1, 3)), size(Rᵏ)...)
end

# ClassToken layer: adds a Class Token; we ignore this token for the relevances
LRP_CONFIG.supports_layer(::ClassTokens) = true
function lrp!(Rᵏ, ::ZeroRule, ::ClassTokens, _modified_layer, aᵏ, Rᵏ⁺¹)
Rᵏ .= Rᵏ⁺¹[:, 2:end, :]
end

# Positional Embedding (you can also use the PassRule)
LRP_CONFIG.supports_layer(::ViPosEmbedding) = true
is_compatible(::PositionalEmbeddingRule, ::ViPosEmbedding) = true
function lrp!(
Rᵏ, ::PositionalEmbeddingRule, layer::ViPosEmbedding, _modified_layer, aᵏ, Rᵏ⁺¹
)
Rᵏ .= aᵏ ./ layer(aᵏ) .* Rᵏ⁺¹
end

# class token selection: only the class token is used for the final predictions,
# so it gets all the relevance
LRP_CONFIG.supports_layer(::SelectClassToken) = true
function lrp!(Rᵏ, ::ZeroRule, ::SelectClassToken, _modified_layer, aᵏ, Rᵏ⁺¹)
fill!(Rᵏ, zero(eltype(Rᵏ)))
Rᵏ[:, 1, :] .= Rᵏ⁺¹
end
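The pattern above (mark the layer as supported via `LRP_CONFIG.supports_layer`, declare rule compatibility via `is_compatible`, then add an `lrp!` method) is the same one needed for any custom layer. A minimal sketch with a hypothetical element-wise layer `MyScale`, which is not part of this PR:

```julia
using RelevancePropagation, Flux
import RelevancePropagation: lrp!, is_compatible

struct MyScale{T}        # hypothetical element-wise scaling layer
    s::T
end
(l::MyScale)(x) = l.s .* x

LRP_CONFIG.supports_layer(::MyScale) = true
is_compatible(::ZeroRule, ::MyScale) = true

# element-wise scaling: zᵢ = s * aᵢ, so under LRP-0 each input keeps
# exactly the relevance of its corresponding output
function lrp!(Rᵏ, ::ZeroRule, ::MyScale, _modified_layer, aᵏ, Rᵏ⁺¹)
    Rᵏ .= Rᵏ⁺¹
end
```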
9 changes: 9 additions & 0 deletions ext/VisionTransformerExt/utils.jl
@@ -0,0 +1,9 @@
function prepare_vit(model::ViT)
model = model.layers # remove wrapper type
model = flatten_model(model) # model consists of nested chains
testmode!(model) # make sure there is no dropout during the forward pass
if !isa(model[end - 2], typeof(seconddimmean))
model = Chain(model[1:(end - 3)]..., SelectClassToken(), model[(end - 1):end]...) # swap the anonymous class-token selector for an actual layer
end
return model
end
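A usage sketch for `prepare_vit`, assuming a Metalhead `ViT`; note that later commits in this PR fold `prepare_vit` into `canonize` and remove it from the exports:

```julia
using RelevancePropagation, Flux, Metalhead, NNlib   # loading Metalhead + NNlib activates the extension

model = prepare_vit(ViT(:tiny; imsize=(224, 224)))
# -> a flat Chain in test mode, with Metalhead's anonymous class-token selector
#    replaced by the SelectClassToken layer defined in this PR
```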
5 changes: 3 additions & 2 deletions src/RelevancePropagation.jl
@@ -27,6 +27,7 @@ include("lrp.jl")
include("show.jl")
include("composite_presets.jl") # uses show.jl
include("crp.jl")
include("extensions.jl")

export LRP
export CRP
@@ -36,7 +37,7 @@ export AbstractLRPRule
export LRP_CONFIG
export ZeroRule, EpsilonRule, GammaRule, WSquareRule, FlatRule
export PassRule, ZBoxRule, ZPlusRule, AlphaBetaRule, GeneralizedGammaRule
export LayerNormRule
export LayerNormRule, PositionalEmbeddingRule, SelfAttentionRule

# LRP composites
export Composite, AbstractCompositePrimitive
@@ -53,6 +54,6 @@ export EpsilonAlpha2Beta1Flat
export ConvLayer, PoolingLayer, DropoutLayer, ReshapingLayer, NormalizationLayer

# utils
export strip_softmax, flatten_model, canonize
export strip_softmax, flatten_model, canonize, prepare_vit

end # module
15 changes: 13 additions & 2 deletions src/canonize.jl
@@ -42,13 +42,16 @@ function split_layer(l::LayerNorm)
diag = Scale(1, l.λ; bias=false)
diag.scale .= 1.0
end
return (layer_norm, diag)
return Chain(layer_norm, diag)
end

is_splittable(l::LayerNorm) = true
is_splittable(l::LayerNorm{F,D,T,N}) where {F,D<:typeof(identity),T,N} = false # don't split any further if the affine part is already the identity
is_splittable(l) = false

# fallback
split_layer(layer) = layer

#=================================#
# Canonize model (split and fuse) #
#=================================#
@@ -88,7 +91,11 @@ function canonize_split(p::Parallel)
return Parallel(p.connection, canonize_split.(p.layers))
end

canonize_split(layer) = layer
function canonize_split(s::SkipConnection)
return SkipConnection(canonize_split(s.layers), s.connection)
end

canonize_split(layer) = split_layer(layer)

function canonize_fuse(model::Chain)
model = Chain(canonize_fuse.(model.layers)) # recursively canonize Parallel layers
@@ -113,4 +120,8 @@ function canonize_fuse(p::Parallel)
return Parallel(p.connection, canonize_fuse.(p.layers))
end

function canonize_fuse(s::SkipConnection)
return SkipConnection(canonize_fuse(s.layers), s.connection)
end

canonize_fuse(layer) = layer
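With the new `SkipConnection` methods, `canonize` now recurses into residual blocks, and the `split_layer` fallback lets splittable layers (currently affine `LayerNorm`s) be expanded in place, so the normalization and the learned affine transform can receive LRP rules separately. A minimal sketch on a toy residual block:

```julia
using Flux, RelevancePropagation

block = SkipConnection(Chain(Dense(4 => 4, relu), LayerNorm(4)), +)
model = Chain(block, Dense(4 => 2))

cmodel = canonize(model)
# the affine LayerNorm inside the SkipConnection is now split into a pure
# normalization step followed by a Scale layer holding the learned parameters
```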
18 changes: 6 additions & 12 deletions src/chain_utils.jl
@@ -187,21 +187,15 @@ flatten_model(x) = chainflatten(x)
Flatten a Flux `Chain` containing `Chain`s. Also works with `ChainTuple`s.
"""
function chainflatten(c::Chain)
if length(c.layers) == 1
return Chain(_chainflatten(c))
else
return Chain(_chainflatten(c)...)
end
return Chain(_chainflatten(c)...)
end

function chainflatten(c::ChainTuple)
if length(c.vals) == 1
return ChainTuple(_chainflatten(c))
else
return ChainTuple(_chainflatten(c)...)
end
return ChainTuple(_chainflatten(c)...)
end
_chainflatten(c::Chain) = mapreduce(_chainflatten, vcat, c.layers)
_chainflatten(c::ChainTuple) = mapreduce(_chainflatten, vcat, c.vals)

_chainflatten(c::Chain) = mapreduce(_chainflatten, vcat, c.layers; init=[])
_chainflatten(c::ChainTuple) = mapreduce(_chainflatten, vcat, c.vals; init=[])

chainflatten(p::Parallel) = _chainflatten(p)
chainflatten(p::ParallelTuple) = _chainflatten(p)
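The `init=[]` keyword ensures `_chainflatten` always returns a collection that can be splatted, so the removed length-1 special cases are no longer needed. A quick sketch of the calls this targets:

```julia
using Flux, RelevancePropagation

flatten_model(Chain(Chain(Dense(2 => 3)), relu))   # nested chains are splatted: Chain(Dense(2 => 3), relu)
flatten_model(Chain(Chain(Dense(2 => 3))))         # the single-layer edge case the fix addresses
```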
56 changes: 56 additions & 0 deletions src/extensions.jl
@@ -0,0 +1,56 @@
# since package extensions should/can not define new types or functions,
# we have to define them here and add the relevant methods in the extension

#=========================#
# VisionTransformerExt #
#=========================#

# layers
"""Flux layer to select the first token, for use with Metalhead.jl's vision transformer."""
struct SelectClassToken end
Flux.@functor SelectClassToken
Inline review comments:

Member:
XAIBase exports generic feature selectors.
Maybe these could be used here and extended for transformers?

Contributor Author:
We can do that, but I don't get how these feature selectors are supposed to be used in a model, or why there are no rules for them.

Contributor Author:
Okay, as we discussed, it does not really make sense to use the feature selectors. I think the remaining question is where you want to define new layers in the codebase - maybe an extra file src/layers.jl?

Member:
> Okay, as we discussed, it does not really make sense to use the feature selectors.

Sorry, it's been a while... Can you remind me what the exact issue was? 😅
I can vaguely remember it was something that should go in XAIBase.jl.
Similar to this: https://github.com/Julia-XAI/XAIBase.jl/blob/main/src/feature_selection.jl

Contributor Author:
Yes, vision transformers have a special token that is selected near the output, and all other tokens are discarded. Metalhead implements this through an anonymous function, so we can't use it for computing LRP. What I did was implement this simple Flux layer (and an associated rule) that is swapped in for the anonymous function. The problem with the feature selector is that we need an actual layer, so I think we decided that this is probably not the right place ^^

(::SelectClassToken)(x) = x[:, 1, :]

# rules
"""
SelfAttentionRule(value_rule=ZeroRule(), out_rule=ZeroRule())
LRP-AH rule. Used on `MultiHeadSelfAttention` layers.
# Definition
Propagates relevance ``R^{k+1}`` at layer output to ``R^k`` at layer input according to
```math
R_i^k = \\sum_j\\frac{\\alpha_{ij} a_i^k}{\\sum_l \\alpha_{lj} a_l^k} R_j^{k+1}
```
where ``\\alpha_{ij}`` are the attention weights.
Relevance through the value projection (before attention) and the out projection (after attention) is by default propagated using the [`ZeroRule`](@ref).
# Optional arguments
- `value_rule`: Rule for the value projection, defaults to `ZeroRule()`
- `out_rule`: Rule for the out projection, defaults to `ZeroRule()`
# References
- $REF_ALI_TRANSFORMER
"""
struct SelfAttentionRule{V,O} <: AbstractLRPRule
value_rule::V
out_rule::O
end

"""
PositionalEmbeddingRule()
To be used with Metalhead.jl's `ViPosEmbedding` layer. Treats the positional embedding like a bias term.
# Definition
Propagates relevance ``R^{k+1}`` at layer output to ``R^k`` at layer input according to
```math
R_i^k = \\frac{a_i^k}{a_i^k + e^i} R_i^{k+1}
```
where ``e^i`` is the learned positional embedding.
"""
struct PositionalEmbeddingRule <: AbstractLRPRule end

# utils
"""Prepare the vision transformer model for the use with `RelevancePropagation.jl`"""
function prepare_vit end
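Putting the pieces together, a hedged end-to-end sketch of how the extension might be used, with names and exports as they appear in this diff (later commits rename the extension and move `prepare_vit` into `canonize`):

```julia
using RelevancePropagation, Flux
using Metalhead, NNlib                  # loading both activates the vision-transformer extension
using XAIBase: MaxActivationSelector    # output selector defined in XAIBase (an assumption: not shown in this diff)

model = ViT(:tiny; imsize=(224, 224))
model = canonize(prepare_vit(model))    # rewrite Metalhead's ViT into an LRP-friendly Chain

analyzer = LRP(model)                   # default rules; SelfAttentionRule / PositionalEmbeddingRule
                                        # can be assigned to the attention / embedding layers as needed
input = rand(Float32, 224, 224, 3, 1)   # WHCN image batch
expl  = analyzer(input, MaxActivationSelector())
size(expl.val)                          # relevances with the same shape as the input
```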
25 changes: 19 additions & 6 deletions src/lrp.jl
@@ -56,23 +56,36 @@ LRP(model::Chain, c::Composite; kwargs...) = LRP(model, lrp_rules(model, c); kwargs...)
#==========================#

function (lrp::LRP)(
input::AbstractArray, ns::AbstractOutputSelector; layerwise_relevances=false
input::AbstractArray,
ns::AbstractOutputSelector;
layerwise_relevances=false,
normalize_output=true,
R=nothing,
)
as = get_activations(lrp.model, input) # compute activations aᵏ for all layers k
Rs = similar.(as) # allocate relevances Rᵏ for all layers k
mask_output_neuron!(Rs[end], as[end], ns) # compute relevance Rᴺ of output layer N

Rs = similar.(as)
if isnothing(R) # allocate relevances Rᵏ for all layers k
mask_output_neuron!(Rs[end], as[end], ns; normalize_output=normalize_output) # compute relevance Rᴺ of output layer N
else
Rs[end] .= R # if there is a user-specified relevance for the last layer, use that instead
end
lrp_backward_pass!(Rs, as, lrp.rules, lrp.model, lrp.modified_layers)
extras = layerwise_relevances ? (layerwise_relevances=Rs,) : nothing
return Explanation(first(Rs), last(as), ns(last(as)), :LRP, :attribution, extras)
end

get_activations(model, input) = (input, Flux.activations(model, input)...)

function mask_output_neuron!(R_out, a_out, ns::AbstractOutputSelector)
function mask_output_neuron!(
R_out, a_out, ns::AbstractOutputSelector; normalize_output=true
)
fill!(R_out, 0)
idx = ns(a_out)
R_out[idx] .= 1
if normalize_output
R_out[idx] .= 1
else
R_out[idx] .= a_out[idx]
end
return R_out
end
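The new keyword arguments change how the backward pass is seeded: with `normalize_output=true` (the default) the selected output neuron gets relevance 1, with `normalize_output=false` it gets its own activation, and `R` overrides the seed entirely. Note that two later commits in this PR revert these options; a hedged sketch of the API exactly as it appears in this diff:

```julia
using RelevancePropagation, Flux
using XAIBase: MaxActivationSelector    # output selector defined in XAIBase (an assumption: not shown in this diff)

model    = Chain(Dense(4 => 3, relu), Dense(3 => 2))
analyzer = LRP(model)
input    = rand(Float32, 4, 1)
ns       = MaxActivationSelector()

analyzer(input, ns)                            # default seed: R_out[selected] .= 1
analyzer(input, ns; normalize_output=false)    # seed with the selected neuron's activation instead
R0 = zeros(Float32, 2, 1); R0[2, 1] = 1.0f0    # fully user-specified output relevance Rᴺ
analyzer(input, ns; R=R0)
```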
