diff --git a/Project.toml b/Project.toml
index e95fdfe..ec04eec 100644
--- a/Project.toml
+++ b/Project.toml
@@ -13,6 +13,12 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 XAIBase = "9b48221d-a747-4c1b-9860-46a1d8ba24a7"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
+[weakdeps]
+Metalhead = "dbeba491-748d-5e0e-a39e-b530a07fa0cc"
+
+[extensions]
+RelevancePropagationMetalheadExt = ["Metalhead"]
+
 [compat]
 Flux = "0.13, 0.14"
 MacroTools = "0.5"
@@ -23,3 +29,4 @@ Statistics = "1"
 XAIBase = "3"
 Zygote = "0.6"
 julia = "1.6"
+Metalhead = "0.9"
\ No newline at end of file
diff --git a/ext/RelevancePropagationMetalheadExt/RelevancePropagationMetalheadExt.jl b/ext/RelevancePropagationMetalheadExt/RelevancePropagationMetalheadExt.jl
new file mode 100644
index 0000000..a83a7f1
--- /dev/null
+++ b/ext/RelevancePropagationMetalheadExt/RelevancePropagationMetalheadExt.jl
@@ -0,0 +1,12 @@
+module RelevancePropagationMetalheadExt
+using RelevancePropagation, Flux
+using RelevancePropagation:
+    SelfAttentionRule, SelectClassToken, PositionalEmbeddingRule, modify_layer # all used types have to be imported explicitly
+import RelevancePropagation: canonize, is_compatible, lrp! # all functions to which you want to add methods have to be imported
+using Metalhead:
+    ViPosEmbedding, ClassTokens, MultiHeadSelfAttention, chunk, ViT, seconddimmean
+using Metalhead.Layers: _flatten_spatial
+
+include("utils.jl")
+include("rules.jl")
+end
diff --git a/ext/RelevancePropagationMetalheadExt/rules.jl b/ext/RelevancePropagationMetalheadExt/rules.jl
new file mode 100644
index 0000000..561916d
--- /dev/null
+++ b/ext/RelevancePropagationMetalheadExt/rules.jl
@@ -0,0 +1,114 @@
+# attention
+SelfAttentionRule() = SelfAttentionRule(ZeroRule(), ZeroRule())
+LRP_CONFIG.supports_layer(::MultiHeadSelfAttention) = true
+is_compatible(::SelfAttentionRule, ::MultiHeadSelfAttention) = true
+
+function lrp!(
+    Rᵏ, rule::SelfAttentionRule, mha::MultiHeadSelfAttention, _modified_layer, aᵏ, Rᵏ⁺¹
+)
+    # query, key, value projections
+    qkv = mha.qkv_layer(aᵏ)
+    q, k, v = chunk(qkv, 3; dims=1)
+    Rᵥ = similar(v)
+    # attention
+    nheads = mha.nheads
+    fdrop = mha.attn_drop
+    bias = nothing
+    # reshape to merge batch dimensions
+    batch_size = size(q)[3:end]
+    batch_size == size(k)[3:end] == size(v)[3:end] ||
+        throw(ArgumentError("Batch dimensions have to be the same."))
+    q, k, v = map(x -> reshape(x, size(x, 1), size(x, 2), :), (q, k, v))
+    # add head dimension
+    q, k, v = split_heads.((q, k, v), nheads)
+    # compute attention scores
+    αt = dot_product_attention_scores(q, k, bias; fdrop)
+    # move head dimension to third place
+    vt = permutedims(v, (1, 3, 2, 4))
+    xt = vt ⊠ αt
+    # remove head dimension
+    x = permutedims(xt, (1, 3, 2, 4))
+    x = join_heads(x)
+    # restore batch dimensions
+    x = reshape(x, size(x, 1), size(x, 2), batch_size...)
+    Rₐ = similar(x)
+
+    # lrp pass
+    ## forward: aᵏ ->(v_proj) v ->(attention) x ->(out_proj) out
+    ## lrp: Rᵏ <-(value_rule) Rᵥ <-(AH-Rule) Rₐ <-(out_rule) Rᵏ⁺¹
+    ## output projection
+    lrp!(
+        Rₐ,
+        rule.out_rule,
+        mha.projection[1],
+        modify_layer(rule.out_rule, mha.projection[1]),
+        x,
+        Rᵏ⁺¹,
+    )
+    ## attention
+    lrp_attention!(Rᵥ, xt, αt, vt, nheads, batch_size, Rₐ)
+    ## value projection
+    _, _, w = chunk(mha.qkv_layer.weight, 3; dims=1)
+    _, _, b = chunk(mha.qkv_layer.bias, 3; dims=1)
+    proj = Dense(w, b)
+    lrp!(Rᵏ, rule.value_rule, proj, modify_layer(rule.value_rule, proj), aᵏ, Rᵥ)
+end
+
+function lrp_attention!(Rᵥ, x, α, v, nheads, batch_size, Rₐ)
+    # input dimensions:
+    ## Rₐ: [embedding x token x batch...]
+    ## x : [embedding x token x head x batch]
+    ## α : [token x token x head x batch]
+    ## v : [embedding x token x head x batch]
+    ## Rᵥ: [embedding x token x batch...]
+
+    # reshape Rₐ: combine batch dimensions, split heads, move head dimension
+    Rₐ = permutedims(
+        split_heads(reshape(Rₐ, size(Rₐ, 1), size(Rₐ, 2), :), nheads), (1, 3, 2, 4)
+    )
+    # compute relevance term
+    s = Rₐ ./ x
+    # add extra dimensions for broadcasting
+    s = reshape(s, size(s, 1), size(s, 2), 1, size(s)[3:end]...)
+    α = reshape(permutedims(α, (2, 1, 3, 4)), 1, size(α)...)
+    # compute relevances, broadcasting over extra dimensions
+    R = α .* s
+    R = dropdims(sum(R; dims=2); dims=2)
+    R = R .* v
+    # reshape relevances (drop extra dimension, move head dimension, join heads, split batch dimensions)
+    R = join_heads(permutedims(R, (1, 3, 2, 4)))
+    Rᵥ .= reshape(R, size(R, 1), size(R, 2), batch_size...)
+end
+
+#=========================#
+#   Special ViT layers    #
+#=========================#
+
+# reshaping image -> token
+LRP_CONFIG.supports_layer(::typeof(_flatten_spatial)) = true
+function lrp!(Rᵏ, ::ZeroRule, ::typeof(_flatten_spatial), _modified_layer, aᵏ, Rᵏ⁺¹)
+    Rᵏ .= reshape(permutedims(Rᵏ⁺¹, (2, 1, 3)), size(Rᵏ)...)
+end
+
+# ClassTokens layer: prepends a class token; this token is ignored when propagating relevance
+LRP_CONFIG.supports_layer(::ClassTokens) = true
+function lrp!(Rᵏ, ::ZeroRule, ::ClassTokens, _modified_layer, aᵏ, Rᵏ⁺¹)
+    Rᵏ .= Rᵏ⁺¹[:, 2:end, :]
+end
+
+# positional embedding (the PassRule can also be used here)
+LRP_CONFIG.supports_layer(::ViPosEmbedding) = true
+is_compatible(::PositionalEmbeddingRule, ::ViPosEmbedding) = true
+function lrp!(
+    Rᵏ, ::PositionalEmbeddingRule, layer::ViPosEmbedding, _modified_layer, aᵏ, Rᵏ⁺¹
+)
+    Rᵏ .= aᵏ ./ layer(aᵏ) .* Rᵏ⁺¹
+end
+
+# class token selection: only the class token is used for the final predictions,
+# so it gets all the relevance
+LRP_CONFIG.supports_layer(::SelectClassToken) = true
+function lrp!(Rᵏ, ::ZeroRule, ::SelectClassToken, _modified_layer, aᵏ, Rᵏ⁺¹)
+    fill!(Rᵏ, zero(eltype(Rᵏ)))
+    Rᵏ[:, 1, :] .= Rᵏ⁺¹
+end
diff --git a/ext/RelevancePropagationMetalheadExt/utils.jl b/ext/RelevancePropagationMetalheadExt/utils.jl
new file mode 100644
index 0000000..f487cac
--- /dev/null
+++ b/ext/RelevancePropagationMetalheadExt/utils.jl
@@ -0,0 +1,14 @@
+function canonize(model::ViT)
+    model = model.layers # remove wrapper type
+    model = flatten_model(model) # model consists of nested chains
+    testmode!(model) # make sure there is no dropout during the forward pass
+    if !isa(model[end - 2], typeof(seconddimmean))
+        model = Chain(model[1:(end - 3)]..., SelectClassToken(), model[(end - 1):end]...) # swap the anonymous function for an actual layer
+    end
+    return canonize(model)
+end
+
+# these are originally from NNlib.jl, but since they are unexported, we don't want
+# to rely on them and re-define them here
+split_heads(x, nheads) = reshape(x, size(x, 1) ÷ nheads, nheads, size(x)[2:end]...)
+join_heads(x) = reshape(x, :, size(x)[3:end]...)
diff --git a/src/RelevancePropagation.jl b/src/RelevancePropagation.jl
index b914dc5..d293525 100644
--- a/src/RelevancePropagation.jl
+++ b/src/RelevancePropagation.jl
@@ -36,7 +36,7 @@ export AbstractLRPRule
 export LRP_CONFIG
 export ZeroRule, EpsilonRule, GammaRule, WSquareRule, FlatRule
 export PassRule, ZBoxRule, ZPlusRule, AlphaBetaRule, GeneralizedGammaRule
-export LayerNormRule
+export LayerNormRule, PositionalEmbeddingRule, SelfAttentionRule
 
 # LRP composites
 export Composite, AbstractCompositePrimitive
diff --git a/src/rules.jl b/src/rules.jl
index 72bda57..22402fc 100644
--- a/src/rules.jl
+++ b/src/rules.jl
@@ -593,3 +593,55 @@ function lrp!(Rᵏ, _rule::FlatRule, _layer::Dense, _modified_layer, _aᵏ, Rᵏ
         fill!(view(Rᵏ, :, i), sum(view(Rᵏ⁺¹, :, i)) / n)
     end
 end
+
+#============#
+# Extensions #
+#============#
+# since package extensions should not (and cannot) define new types or functions,
+# we have to define them here and add the relevant methods in the extension
+
+# RelevancePropagationMetalheadExt
+
+"""Flux layer to select the first token, for use with Metalhead.jl's vision transformer."""
+struct SelectClassToken end
+Flux.@functor SelectClassToken
+(::SelectClassToken)(x) = x[:, 1, :]
+
+"""
+    SelfAttentionRule(value_rule=ZeroRule(), out_rule=ZeroRule())
+
+LRP-AH rule. Used on `MultiHeadSelfAttention` layers from Metalhead.jl. Metalhead.jl has to be loaded to make use of this rule.
+
+# Definition
+Propagates relevance ``R^{k+1}`` at layer output to ``R^k`` at layer input according to
+```math
+R_i^k = \\sum_j\\frac{\\alpha_{ij} a_i^k}{\\sum_l \\alpha_{lj} a_l^k} R_j^{k+1}
+```
+where ``\\alpha_{ij}`` are the attention weights.
+Relevance through the value projection (before attention) and the out projection (after attention) is by default propagated using the [`ZeroRule`](@ref).
+
+# Optional arguments
+- `value_rule`: Rule for the value projection, defaults to `ZeroRule()`
+- `out_rule`: Rule for the out projection, defaults to `ZeroRule()`
+
+# References
+- $REF_ALI_TRANSFORMER
+"""
+struct SelfAttentionRule{V,O} <: AbstractLRPRule
+    value_rule::V
+    out_rule::O
+end
+
+"""
+    PositionalEmbeddingRule()
+
+To be used with Metalhead.jl's `ViPosEmbedding` layer. Treats the positional embedding like a bias term. Metalhead.jl has to be loaded to make use of this rule.
+
+# Definition
+Propagates relevance ``R^{k+1}`` at layer output to ``R^k`` at layer input according to
+```math
+R_i^k = \\frac{a_i^k}{a_i^k + e^i} R_i^{k+1}
+```
+where ``e^i`` is the learned positional embedding.
+"""
+struct PositionalEmbeddingRule <: AbstractLRPRule end
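For context, the sketch below shows how the pieces added in this diff could be wired together once Metalhead.jl is loaded. It is a minimal usage sketch and not part of the patch: the `Composite`/`GlobalTypeMap` primitives and the `analyze` entry point are assumed from the existing RelevancePropagation.jl/XAIBase API, and `ViT(:tiny)` with a 256×256 input size is an assumption about Metalhead 0.9's defaults.

```julia
# Minimal usage sketch (not part of this diff). GlobalTypeMap, analyze, ViT(:tiny)
# and the 256×256 input size are assumptions about the surrounding APIs.
using Flux, Metalhead, RelevancePropagation

vit   = Metalhead.ViT(:tiny)  # loading Metalhead activates RelevancePropagationMetalheadExt
model = canonize(vit)         # canonize(::ViT): unwrap, flatten, swap in SelectClassToken

# Assign the new rules to the transformer-specific layer types;
# all other layers keep the analyzer's default ZeroRule.
composite = Composite(
    GlobalTypeMap(
        Metalhead.MultiHeadSelfAttention => SelfAttentionRule(),
        Metalhead.ViPosEmbedding => PositionalEmbeddingRule(),
        Flux.LayerNorm => LayerNormRule(),
    ),
)
analyzer = LRP(model, composite)

input = rand(Float32, 256, 256, 3, 1)  # WHCN image batch matching the assumed input size
expl  = analyze(input, analyzer)       # Explanation object with the relevance scores
```

Because `SelfAttentionRule()` defaults both projections to `ZeroRule()`, passing e.g. `SelfAttentionRule(EpsilonRule(), EpsilonRule())` in the map only changes how relevance flows through the value and output projections; the attention-head redistribution implemented in `lrp_attention!` stays the same.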