26 commits
c213416
add option to normalize / set output relevances
Maximilian-Stefan-Ernst Mar 7, 2024
1ed33fd
start adding package extension for vision transformers
Maximilian-Stefan-Ernst Mar 7, 2024
a80ba77
add NNlib to weakdeps
Maximilian-Stefan-Ernst Mar 7, 2024
4822aa7
fix merge conflict
Maximilian-Stefan-Ernst Mar 7, 2024
b07335a
fix dependencies for extension
Maximilian-Stefan-Ernst Mar 15, 2024
4f79949
add extension
Maximilian-Stefan-Ernst Mar 15, 2024
9f6d5e2
add options to specify last layer relevance
Maximilian-Stefan-Ernst Mar 15, 2024
bfc9e70
fix imports
Maximilian-Stefan-Ernst Mar 15, 2024
5604403
add canonization for SkipConnection layers; fix model splitting edge …
Maximilian-Stefan-Ernst Mar 15, 2024
27b22d2
Merge pull request #1 from Maximilian-Stefan-Ernst/hotfix/canonize
Maximilian-Stefan-Ernst Mar 15, 2024
655b51d
fix formatting
Maximilian-Stefan-Ernst Mar 15, 2024
f8762e1
fix flatten_model edge case
Maximilian-Stefan-Ernst Mar 15, 2024
4e14408
add tests for flatten_model and canonize
Maximilian-Stefan-Ernst Mar 15, 2024
8e606b5
Merge pull request #2 from Maximilian-Stefan-Ernst/hotfix/canonize
Maximilian-Stefan-Ernst Mar 15, 2024
72dcc76
fix in prepare_vit
Maximilian-Stefan-Ernst Mar 15, 2024
e7c7bac
formatter
Maximilian-Stefan-Ernst Mar 15, 2024
06b4280
Merge branch 'attention' into main
Maximilian-Stefan-Ernst Mar 19, 2024
13ba2b7
Merge pull request #3 from Maximilian-Stefan-Ernst/main
Maximilian-Stefan-Ernst Mar 19, 2024
e4721fd
Revert "add options to specify last layer relevance"
Maximilian-Stefan-Ernst Mar 23, 2024
d464696
Revert "add option to normalize / set output relevances"
Maximilian-Stefan-Ernst Mar 23, 2024
f680bd6
remove NNlib dependency
Maximilian-Stefan-Ernst Mar 23, 2024
31ac29b
rename extension
Maximilian-Stefan-Ernst Mar 23, 2024
6cac3e7
rename files to new extension name & make prepare_vit part of canonize
Maximilian-Stefan-Ernst Mar 23, 2024
e2c7e30
move extension rules to src/rules.jl
Maximilian-Stefan-Ernst Mar 23, 2024
17da330
tell users they have to load Metalhead to use the rules
Maximilian-Stefan-Ernst Mar 23, 2024
79b2d40
remove prepare_vit from exports; format
Maximilian-Stefan-Ernst Mar 23, 2024
7 changes: 7 additions & 0 deletions Project.toml
@@ -13,6 +13,12 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
XAIBase = "9b48221d-a747-4c1b-9860-46a1d8ba24a7"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[weakdeps]
Metalhead = "dbeba491-748d-5e0e-a39e-b530a07fa0cc"

[extensions]
RelevancePropagationMetalheadExt = ["Metalhead"]

[compat]
Flux = "0.13, 0.14"
MacroTools = "0.5"
@@ -23,3 +29,4 @@ Statistics = "1"
XAIBase = "3"
Zygote = "0.6"
julia = "1.6"
Metalhead = "0.9"
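With the `[weakdeps]` and `[extensions]` entries above, the new code only activates when Metalhead.jl is loaded alongside RelevancePropagation.jl. A minimal sketch of what this looks like from the user's side (assuming Julia ≥ 1.9, where package extensions are available):

```julia
using RelevancePropagation   # the extension is not loaded yet
using Metalhead              # loading the weak dependency triggers RelevancePropagationMetalheadExt

# optional check that the extension module was picked up
Base.get_extension(RelevancePropagation, :RelevancePropagationMetalheadExt)
```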
12 changes: 12 additions & 0 deletions ext/RelevancePropagationMetalheadExt/RelevancePropagationMetalheadExt.jl
@@ -0,0 +1,12 @@
module RelevancePropagationMetalheadExt
using RelevancePropagation, Flux
using RelevancePropagation:
SelfAttentionRule, SelectClassToken, PositionalEmbeddingRule, modify_layer # all used types have to be used explicitly
import RelevancePropagation: canonize, is_compatible, lrp! # all functions to which you want to add methods have to be imported
using Metalhead:
ViPosEmbedding, ClassTokens, MultiHeadSelfAttention, chunk, ViT, seconddimmean
using Metalhead.Layers: _flatten_spatial

include("utils.jl")
include("rules.jl")
end
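The comments above spell out the usual extension pattern: functions and types are owned by the parent package, and the extension only adds methods for the weak dependency's types. A stripped-down sketch of that pattern with purely hypothetical module and type names (nothing here is Metalhead-specific):

```julia
# stands in for the weak dependency (Metalhead.jl in this PR)
module WeakDep
struct ForeignLayer end
end

# stands in for the parent package: owns the generic function and the rule type
module Parent
abstract type AbstractRule end
struct MyRule <: AbstractRule end
is_compatible(::AbstractRule, layer) = false
end

# stands in for the extension: imports the function it extends, uses the parent's
# types explicitly, and adds methods dispatching on the weak dependency's types
module ParentWeakDepExt
import ..Parent: is_compatible   # functions you add methods to must be imported
using ..Parent: MyRule           # types you dispatch on must be brought into scope
using ..WeakDep: ForeignLayer
is_compatible(::MyRule, ::ForeignLayer) = true
end

Parent.is_compatible(Parent.MyRule(), WeakDep.ForeignLayer())  # true
```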
114 changes: 114 additions & 0 deletions ext/RelevancePropagationMetalheadExt/rules.jl
@@ -0,0 +1,114 @@
# attention
SelfAttentionRule() = SelfAttentionRule(ZeroRule(), ZeroRule())
LRP_CONFIG.supports_layer(::MultiHeadSelfAttention) = true
is_compatible(::SelfAttentionRule, ::MultiHeadSelfAttention) = true

function lrp!(
Rᵏ, rule::SelfAttentionRule, mha::MultiHeadSelfAttention, _modified_layer, aᵏ, Rᵏ⁺¹
)
# query, key, value projections
qkv = mha.qkv_layer(aᵏ)
q, k, v = chunk(qkv, 3; dims=1)
Rᵥ = similar(v)
# attention
nheads = mha.nheads
fdrop = mha.attn_drop
bias = nothing
# reshape to merge batch dimensions
batch_size = size(q)[3:end]
batch_size == size(k)[3:end] == size(v)[3:end] ||
throw(ArgumentError("Batch dimensions have to be the same."))
q, k, v = map(x -> reshape(x, size(x, 1), size(x, 2), :), (q, k, v))
# add head dimension
q, k, v = split_heads.((q, k, v), nheads)
# compute attention scores
αt = dot_product_attention_scores(q, k, bias; fdrop)
# move head dimension to third place
vt = permutedims(v, (1, 3, 2, 4))
xt = vt ⊠ αt
# remove head dimension
x = permutedims(xt, (1, 3, 2, 4))
x = join_heads(x)
# restore batch dimensions
x = reshape(x, size(x, 1), size(x, 2), batch_size...)
Rₐ = similar(x)

# lrp pass
## forward: aᵏ ->(v_proj) v ->(attention) x ->(out_proj) out
## lrp: Rᵏ <-(value_rule) Rᵥ <-(AH-Rule) Rₐ <-(out_rule) Rᵏ⁺¹
## output projection
lrp!(
Rₐ,
rule.out_rule,
mha.projection[1],
modify_layer(rule.out_rule, mha.projection[1]),
x,
Rᵏ⁺¹,
)
## attention
lrp_attention!(Rᵥ, xt, αt, vt, nheads, batch_size, Rₐ)
## value projection
_, _, w = chunk(mha.qkv_layer.weight, 3; dims=1)
_, _, b = chunk(mha.qkv_layer.bias, 3; dims=1)
proj = Dense(w, b)
lrp!(Rᵏ, rule.value_rule, proj, modify_layer(rule.value_rule, proj), aᵏ, Rᵥ)
end

function lrp_attention!(Rᵥ, x, α, v, nheads, batch_size, Rₐ)
# input dimensions:
## Rₐ: [embedding x token x batch...]
## x : [embedding x token x head x batch]
## α : [token x token x head x batch]
## v : [embedding x token x head x batch]
## Rᵥ: [embedding x token x batch...]

# reshape Rₐ: combine batch dimensions, split heads, move head dimension
Rₐ = permutedims(
split_heads(reshape(Rₐ, size(Rₐ, 1), size(Rₐ, 2), :), nheads), (1, 3, 2, 4)
)
# compute relevance term
s = Rₐ ./ x
# add extra dimensions for broadcasting
s = reshape(s, size(s, 1), size(s, 2), 1, size(s)[3:end]...)
α = reshape(permutedims(α, (2, 1, 3, 4)), 1, size(α)...)
# compute relevances, broadcasting over extra dimensions
R = α .* s
R = dropdims(sum(R; dims=2); dims=2)
R = R .* v
# reshape relevances (drop extra dimension, move head dimension, join heads, split batch dimensions)
R = join_heads(permutedims(R, (1, 3, 2, 4)))
Rᵥ .= reshape(R, size(R, 1), size(R, 2), batch_size...)
end

#=========================#
# Special ViT layers #
#=========================#

# reshaping image -> token
LRP_CONFIG.supports_layer(::typeof(_flatten_spatial)) = true
function lrp!(Rᵏ, ::ZeroRule, ::typeof(_flatten_spatial), _modified_layer, aᵏ, Rᵏ⁺¹)
Rᵏ .= reshape(permutedims(Rᵏ⁺¹, (2, 1, 3)), size(Rᵏ)...)
end

# ClassToken layer: adds a Class Token; we ignore this token for the relevances
LRP_CONFIG.supports_layer(::ClassTokens) = true
function lrp!(Rᵏ, ::ZeroRule, ::ClassTokens, _modified_layer, aᵏ, Rᵏ⁺¹)
Rᵏ .= Rᵏ⁺¹[:, 2:end, :]
end

# Positional Embedding (you can also use the PassRule)
LRP_CONFIG.supports_layer(::ViPosEmbedding) = true
is_compatible(::PositionalEmbeddingRule, ::ViPosEmbedding) = true
function lrp!(
Rᵏ, ::PositionalEmbeddingRule, layer::ViPosEmbedding, _modified_layer, aᵏ, Rᵏ⁺¹
)
Rᵏ .= aᵏ ./ layer(aᵏ) .* Rᵏ⁺¹
end

# class token selection: only the class token is used for the final predictions,
# so it gets all the relevance
LRP_CONFIG.supports_layer(::SelectClassToken) = true
function lrp!(Rᵏ, ::ZeroRule, ::SelectClassToken, _modified_layer, aᵏ, Rᵏ⁺¹)
fill!(Rᵏ, zero(eltype(Rᵏ)))
Rᵏ[:, 1, :] .= Rᵏ⁺¹
end
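As a plausibility check on the redistribution that `lrp_attention!` performs in batched form, here is the element-wise version from the `SelfAttentionRule` docstring for a single head with scalar embeddings (toy numbers, variable names hypothetical); note that relevance is conserved:

```julia
# a: value activations per token, α: attention weights (column j = weights for query j),
# R_out: relevance arriving at the attention output
a     = [1.0, 2.0, 3.0]              # one scalar "embedding" per token
α     = [0.2 0.5; 0.3 0.1; 0.5 0.4]  # 3 value tokens × 2 queries, columns sum to 1
R_out = [0.6, 0.4]

denom = α' * a                       # Σ_l α_lj a_l, one entry per query j
R_in  = a .* (α * (R_out ./ denom))  # R_i = Σ_j α_ij a_i / (Σ_l α_lj a_l) ⋅ R_j

sum(R_in) ≈ sum(R_out)               # true: relevance is conserved
```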
14 changes: 14 additions & 0 deletions ext/RelevancePropagationMetalheadExt/utils.jl
@@ -0,0 +1,14 @@
function canonize(model::ViT)
model = model.layers # remove wrapper type
model = flatten_model(model) # model consists of nested chains
testmode!(model) # make sure there is no dropout during the forward pass
if !isa(model[end - 2], typeof(seconddimmean))
model = Chain(model[1:(end - 3)]..., SelectClassToken(), model[(end - 1):end]...) # replace the anonymous class-token selector with an actual layer
end
return canonize(model)
end

# these are originally from NNlib.jl, but since they are unexported, we don't want
# to rely on them and instead re-define them here
split_heads(x, nheads) = reshape(x, size(x, 1) ÷ nheads, nheads, size(x)[2:end]...)
join_heads(x) = reshape(x, :, size(x)[3:end]...)
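A quick shape check for the two re-defined helpers (illustrative sizes, assuming the definitions above are in scope):

```julia
x = rand(Float32, 8, 5, 2)   # embedding × tokens × batch
h = split_heads(x, 4)        # size (2, 4, 5, 2): (embedding ÷ nheads) × nheads × tokens × batch
join_heads(h) == x           # true: reshaping back restores the original layout
```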
2 changes: 1 addition & 1 deletion src/RelevancePropagation.jl
@@ -36,7 +36,7 @@ export AbstractLRPRule
export LRP_CONFIG
export ZeroRule, EpsilonRule, GammaRule, WSquareRule, FlatRule
export PassRule, ZBoxRule, ZPlusRule, AlphaBetaRule, GeneralizedGammaRule
export LayerNormRule
export LayerNormRule, PositionalEmbeddingRule, SelfAttentionRule

# LRP composites
export Composite, AbstractCompositePrimitive
52 changes: 52 additions & 0 deletions src/rules.jl
@@ -593,3 +593,55 @@ function lrp!(Rᵏ, _rule::FlatRule, _layer::Dense, _modified_layer, _aᵏ, Rᵏ⁺¹)
fill!(view(Rᵏ, :, i), sum(view(Rᵏ⁺¹, :, i)) / n)
end
end

#============#
# Extensions #
#============#
# since package extensions should not define new types or functions,
# we have to define them here and add the relevant methods in the extension

# RelevancePropagationMetalheadExt

"""Flux layer to select the first token, for use with Metalhead.jl's vision transformer."""
struct SelectClassToken end
Flux.@functor SelectClassToken
(::SelectClassToken)(x) = x[:, 1, :]

"""
SelfAttentionRule(value_rule=ZeroRule(), out_rule=ZeroRule())

LRP-AH rule. Used on `MultiHeadSelfAttention` layers from Metalhead.jl. Metalhead.jl has to be loaded to make use of this rule.

# Definition
Propagates relevance ``R^{k+1}`` at layer output to ``R^k`` at layer input according to
```math
R_i^k = \\sum_j\\frac{\\alpha_{ij} a_i^k}{\\sum_l \\alpha_{lj} a_l^k} R_j^{k+1}
```
where ``\\alpha_{ij}`` are the attention weights.
Relevance through the value projection (before attention) and the out projection (after attention) is by default propagated using the [`ZeroRule`](@ref).

# Optional arguments
- `value_rule`: Rule for the value projection, defaults to `ZeroRule()`
- `out_rule`: Rule for the out projection, defaults to `ZeroRule()`

# References
- $REF_ALI_TRANSFORMER
"""
struct SelfAttentionRule{V,O} <: AbstractLRPRule
value_rule::V
out_rule::O
end

"""
PositionalEmbeddingRule()

To be used with Metalhead.jl's `ViPosEmbedding` layer. Treats the positional embedding like a bias term. Metalhead.jl has to be loaded to make use of this rule.

# Definition
Propagates relevance ``R^{k+1}`` at layer output to ``R^k`` at layer input according to
```math
R_i^k = \\frac{a_i^k}{a_i^k + e^i} R_i^{k+1}
```
where ``e^i`` is the learned positional embedding.
"""
struct PositionalEmbeddingRule <: AbstractLRPRule end
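Putting the pieces together, a hypothetical end-to-end sketch of how these rules might be used once Metalhead.jl is loaded; the model choice, the composite layout, and the input size are illustrative assumptions, not part of this diff:

```julia
using RelevancePropagation, Flux, Metalhead

model = canonize(ViT(:tiny))   # ViT canonization is added by this extension

composite = Composite(
    GlobalTypeMap(
        Metalhead.MultiHeadSelfAttention => SelfAttentionRule(),
        Metalhead.ViPosEmbedding         => PositionalEmbeddingRule(),
        LayerNorm                        => LayerNormRule(),
        Dense                            => EpsilonRule(),
    ),
)
analyzer = LRP(model, composite)

x = rand(Float32, 256, 256, 3, 1)   # input size must match the model's configured imsize
expl = analyze(x, analyzer)         # Explanation holding the relevances for the predicted class
```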