Skip to content

Commit 699a3b8

Browse files
authored
Fast WSquareRule on Dense layers (#98)
1 parent d26a833 commit 699a3b8

File tree

1 file changed

+5
-0
lines changed

1 file changed

+5
-0
lines changed

src/lrp/rules.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -452,3 +452,8 @@ function lrp!(Rₖ, ::FlatRule, layer::Dense, aₖ, Rₖ₊₁)
452452
end
453453
return nothing
454454
end
455+
function lrp!(Rₖ, ::WSquareRule, layer::Dense, aₖ, Rₖ₊₁)
456+
den = sum(layer.weight .^ 2; dims=2)
457+
@tullio Rₖ[j, b] = layer.weight[i, j]^2 / den[i] * Rₖ₊₁[i, b]
458+
return nothing
459+
end

0 commit comments

Comments
 (0)