Add support for (vision) transformers, add options to set last-layer relevance #15
@@ -0,0 +1,13 @@
module VisionTransformerExt
using RelevancePropagation, Flux
using RelevancePropagation:
    SelfAttentionRule, SelectClassToken, PositionalEmbeddingRule, modify_layer # all used types have to be listed explicitly
import RelevancePropagation: prepare_vit, is_compatible, lrp! # all functions to which methods are added have to be imported
using Metalhead:
    ViPosEmbedding, ClassTokens, MultiHeadSelfAttention, chunk, ViT, seconddimmean
using Metalhead.Layers: _flatten_spatial
using NNlib: split_heads, join_heads

include("rules.jl")
include("utils.jl")
end
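For orientation, a minimal sketch of how this extension is expected to activate: loading Metalhead next to RelevancePropagation should load `VisionTransformerExt`, after which attention layers are reported as supported. The `Base.get_extension` query and the `MultiHeadSelfAttention(64)` constructor call are assumptions for illustration, not part of this diff.

```julia
using Flux, Metalhead, RelevancePropagation

# Loading Metalhead alongside RelevancePropagation should trigger the extension
# (assuming it is declared under [extensions] in Project.toml).
ext = Base.get_extension(RelevancePropagation, :VisionTransformerExt)
@assert !isnothing(ext)

# With the extension loaded, the attention layer type is marked as supported.
mha = Metalhead.MultiHeadSelfAttention(64)  # assumed constructor signature
@assert RelevancePropagation.LRP_CONFIG.supports_layer(mha)
```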
@@ -0,0 +1,114 @@
# attention
SelfAttentionRule() = SelfAttentionRule(ZeroRule(), ZeroRule())
LRP_CONFIG.supports_layer(::MultiHeadSelfAttention) = true
is_compatible(::SelfAttentionRule, ::MultiHeadSelfAttention) = true

function lrp!(
    Rᵏ, rule::SelfAttentionRule, mha::MultiHeadSelfAttention, _modified_layer, aᵏ, Rᵏ⁺¹
)
    # query, key, value projections
    qkv = mha.qkv_layer(aᵏ)
    q, k, v = chunk(qkv, 3; dims=1)
    Rᵥ = similar(v)
    # attention
    nheads = mha.nheads
    fdrop = mha.attn_drop
    bias = nothing
    # reshape to merge batch dimensions
    batch_size = size(q)[3:end]
    batch_size == size(k)[3:end] == size(v)[3:end] ||
        throw(ArgumentError("Batch dimensions have to be the same."))
    q, k, v = map(x -> reshape(x, size(x, 1), size(x, 2), :), (q, k, v))
    # add head dimension
    q, k, v = split_heads.((q, k, v), nheads)
    # compute attention scores
    αt = dot_product_attention_scores(q, k, bias; fdrop)
    # move head dimension to third place
    vt = permutedims(v, (1, 3, 2, 4))
    xt = vt ⊠ αt
    # remove head dimension
    x = permutedims(xt, (1, 3, 2, 4))
    x = join_heads(x)
    # restore batch dimensions
    x = reshape(x, size(x, 1), size(x, 2), batch_size...)
    Rₐ = similar(x)

    # lrp pass
    ## forward: aᵏ ->(v_proj) v ->(attention) x ->(out_proj) out
    ## lrp:     Rᵏ <-(value_rule) Rᵥ <-(AH-Rule) Rₐ <-(out_rule) Rᵏ⁺¹
    ## output projection
    lrp!(
        Rₐ,
        rule.out_rule,
        mha.projection[1],
        modify_layer(rule.out_rule, mha.projection[1]),
        x,
        Rᵏ⁺¹,
    )
    ## attention
    lrp_attention!(Rᵥ, xt, αt, vt, nheads, batch_size, Rₐ)
    ## value projection
    _, _, w = chunk(mha.qkv_layer.weight, 3; dims=1)
    _, _, b = chunk(mha.qkv_layer.bias, 3; dims=1)
    proj = Dense(w, b)
    lrp!(Rᵏ, rule.value_rule, proj, modify_layer(rule.value_rule, proj), aᵏ, Rᵥ)
end
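An illustrative aside on the value-projection step at the end of this rule: the fused `qkv_layer` is a single `Dense` whose output rows stack query, key and value, so chunking its weight and bias yields a standalone value projection. This is a toy sketch with made-up sizes; `MLUtils.chunk` is presumably the same function that `chunk` above refers to via Metalhead.

```julia
using Flux, MLUtils

embed = 8
qkv = Dense(embed => 3embed)                     # fused query/key/value projection (toy size)
_, _, Wv = MLUtils.chunk(qkv.weight, 3; dims=1)  # last third of the rows is the value projection
_, _, bv = MLUtils.chunk(qkv.bias, 3; dims=1)
value_proj = Dense(Wv, bv)

a = randn(Float32, embed, 5)                     # 5 tokens
_, _, v = MLUtils.chunk(qkv(a), 3; dims=1)
@assert v ≈ value_proj(a)                        # the split reproduces the value path
```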

function lrp_attention!(Rᵥ, x, α, v, nheads, batch_size, Rₐ)
    # input dimensions:
    ## Rₐ: [embedding x token x batch...]
    ## x : [embedding x token x head x batch]
    ## α : [token x token x head x batch]
    ## v : [embedding x token x head x batch]
    ## Rᵥ: [embedding x token x batch...]

    # reshape Rₐ: combine batch dimensions, split heads, move head dimension
    Rₐ = permutedims(
        split_heads(reshape(Rₐ, size(Rₐ, 1), size(Rₐ, 2), :), nheads), (1, 3, 2, 4)
    )
    # compute relevance term
    s = Rₐ ./ x
    # add extra dimensions for broadcasting
    s = reshape(s, size(s, 1), size(s, 2), 1, size(s)[3:end]...)
    α = reshape(permutedims(α, (2, 1, 3, 4)), 1, size(α)...)
    # compute relevances, broadcasting over extra dimensions
    R = α .* s
    R = dropdims(sum(R; dims=2); dims=2)
    R = R .* v
    # reshape relevances (drop extra dimension, move head dimension, join heads, split batch dimensions)
    R = join_heads(permutedims(R, (1, 3, 2, 4)))
    Rᵥ .= reshape(R, size(R, 1), size(R, 2), batch_size...)
end
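To make the AH rule concrete, here is a toy version with a single head, a single embedding dimension and one batch element; it follows the formula documented on `SelfAttentionRule` and checks that relevance is conserved. All numbers are made up for illustration.

```julia
α = [0.7 0.2; 0.3 0.8]                # α[i, j]: attention paid by output token j to input token i
v = [1.0, 2.0]                        # value of each input token (one embedding dimension)
x = [sum(α[:, j] .* v) for j in 1:2]  # attention output per output token
Rₐ = [0.5, 0.5]                       # relevance arriving at the attention output

# AH rule: Rᵥ[i] = Σⱼ α[i, j] * v[i] / x[j] * Rₐ[j]
Rᵥ = [sum(α[i, :] .* v[i] ./ x .* Rₐ) for i in 1:2]
@assert sum(Rᵥ) ≈ sum(Rₐ)             # relevance is conserved across the attention step
```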

#=========================#
#   Special ViT layers    #
#=========================#

# reshaping image -> token
LRP_CONFIG.supports_layer(::typeof(_flatten_spatial)) = true
function lrp!(Rᵏ, ::ZeroRule, ::typeof(_flatten_spatial), _modified_layer, aᵏ, Rᵏ⁺¹)
    Rᵏ .= reshape(permutedims(Rᵏ⁺¹, (2, 1, 3)), size(Rᵏ)...)
end

# ClassTokens layer: adds a class token; we ignore this token for the relevances
LRP_CONFIG.supports_layer(::ClassTokens) = true
function lrp!(Rᵏ, ::ZeroRule, ::ClassTokens, _modified_layer, aᵏ, Rᵏ⁺¹)
    Rᵏ .= Rᵏ⁺¹[:, 2:end, :]
end

# Positional Embedding (you can also use the PassRule)
LRP_CONFIG.supports_layer(::ViPosEmbedding) = true
is_compatible(::PositionalEmbeddingRule, ::ViPosEmbedding) = true
function lrp!(
    Rᵏ, ::PositionalEmbeddingRule, layer::ViPosEmbedding, _modified_layer, aᵏ, Rᵏ⁺¹
)
    Rᵏ .= aᵏ ./ layer(aᵏ) .* Rᵏ⁺¹
end
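Since `ViPosEmbedding` simply adds a learned embedding to its input, the rule above scales each relevance entry by the activation's share `a / (a + e)`, with the rest implicitly attributed to the embedding. A toy, one-dimensional illustration with made-up numbers:

```julia
a    = [1.0, 3.0]  # token activations (one embedding dimension, two tokens)
e    = [0.5, 1.0]  # learned positional embedding
Rᵏ⁺¹ = [0.3, 0.8]  # incoming relevance

Rᵏ = a ./ (a .+ e) .* Rᵏ⁺¹  # PositionalEmbeddingRule
# Rᵏ ≈ [0.2, 0.6]; the remaining relevance is absorbed by the embedding
```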

# class token selection: only the class token is used for the final predictions,
# so it gets all the relevance
LRP_CONFIG.supports_layer(::SelectClassToken) = true
function lrp!(Rᵏ, ::ZeroRule, ::SelectClassToken, _modified_layer, aᵏ, Rᵏ⁺¹)
    fill!(Rᵏ, zero(eltype(Rᵏ)))
    Rᵏ[:, 1, :] .= Rᵏ⁺¹
end
@@ -0,0 +1,9 @@
function prepare_vit(model::ViT)
    model = model.layers # remove wrapper type
    model = flatten_model(model) # the model consists of nested chains
    testmode!(model) # make sure there is no dropout during the forward pass
    if !isa(model[end - 2], typeof(seconddimmean))
        model = Chain(model[1:(end - 3)]..., SelectClassToken(), model[(end - 1):end]...) # replace the anonymous class-token selector with an actual layer
    end
    return model
end
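A sketch of the intended end-to-end workflow. The `ViT(:tiny)` constructor, the `LRP`/`analyze` calls and the qualified `prepare_vit` access are assumptions about the surrounding APIs, not something this diff establishes.

```julia
using Flux, Metalhead, RelevancePropagation

vit = ViT(:tiny)                               # assumed Metalhead constructor
model = RelevancePropagation.prepare_vit(vit)  # flatten, disable dropout, insert SelectClassToken

# A composite would map SelfAttentionRule() onto the attention layers and
# PositionalEmbeddingRule() onto the positional embedding; the plain analyzer
# is shown here for brevity (assumed RelevancePropagation API).
analyzer = LRP(model)
input = rand(Float32, 224, 224, 3, 1)
explanation = analyze(input, analyzer)         # assumed XAIBase-style entry point
```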
@@ -0,0 +1,56 @@
# since package extensions should not (and cannot reliably) define new types or functions,
# we have to define them here and add the relevant methods in the extension

#=========================#
#  VisionTransformerExt   #
#=========================#

# layers
"""Flux layer to select the first token, for use with Metalhead.jl's vision transformer."""
struct SelectClassToken end
Flux.@functor SelectClassToken

(::SelectClassToken)(x) = x[:, 1, :]

# rules
""" | ||
SelfAttentionRule(value_rule=ZeroRule(), out_rule=ZeroRule) | ||
LRP-AH rule. Used on `MultiHeadSelfAttention` layers. | ||
Maximilian-Stefan-Ernst marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
# Definition | ||
Propagates relevance ``R^{k+1}`` at layer output to ``R^k`` at layer input according to | ||
```math | ||
R_i^k = \\sum_j\\frac{\\alpha_{ij} a_i^k}{\\sum_l \\alpha_{lj} a_l^k} R_j^{k+1} | ||
``` | ||
where ``alpha_{ij}`` are the attention weights. | ||
Relevance through the value projection (before attention) and the out projection (after attention) is by default propagated using the [`ZeroRule`](@ref). | ||
# Optional arguments | ||
- `value_rule`: Rule for the value projection, defaults to `ZeroRule()` | ||
- `out_rule`: Rule for the out projection, defaults to `ZeroRule()` | ||
# References | ||
- $REF_ALI_TRANSFORMER | ||
""" | ||
struct SelfAttentionRule{V,O} <: AbstractLRPRule | ||
value_rule::V | ||
out_rule::O | ||
end | ||

"""
    PositionalEmbeddingRule()

To be used with Metalhead.jl's `ViPosEmbedding` layer. Treats the positional embedding like a bias term.

# Definition
Propagates relevance ``R^{k+1}`` at layer output to ``R^k`` at layer input according to
```math
R_i^k = \\frac{a_i^k}{a_i^k + e_i} R_i^{k+1}
```
where ``e_i`` is the learned positional embedding.
"""
struct PositionalEmbeddingRule <: AbstractLRPRule end

# utils
"""Prepare the vision transformer model for use with `RelevancePropagation.jl`."""
function prepare_vit end
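Finally, a short sketch of how these rule types are meant to be constructed. `EpsilonRule()` is assumed to be the package's ε-stabilised rule with a default constructor; the explicit `using ... :` form is used since the new rules may not be exported.

```julia
using RelevancePropagation: SelfAttentionRule, PositionalEmbeddingRule, EpsilonRule

# Default: ZeroRule for both linear projections around the attention step.
attn_rule = SelfAttentionRule()

# Custom sub-rules for the value and output projections.
attn_rule_eps = SelfAttentionRule(EpsilonRule(), EpsilonRule())

# Rule for Metalhead's ViPosEmbedding layer (PassRule is an alternative, see rules.jl).
pos_rule = PositionalEmbeddingRule()
```

In a composite, `attn_rule` would be assigned to `MultiHeadSelfAttention` layers and `pos_rule` to `ViPosEmbedding` layers.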