
Commit a09f5ee

Make Tullio a weak dependency using package extensions (#141)

* Make Tullio a weak dependency using package extensions. For backwards compatibility, it is still a normal dependency on Julia versions prior to 1.9.
* Skip Aqua.jl's Project.toml tests on Julia 1.6.
1 parent: e52bde8

7 files changed: +61 −18 lines


Project.toml

Lines changed: 6 additions & 0 deletions
@@ -16,6 +16,12 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 Tullio = "bc48ee85-29a4-5162-ae0b-a64e1601d4bc"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
+[weakdeps]
+Tullio = "bc48ee85-29a4-5162-ae0b-a64e1601d4bc"
+
+[extensions]
+TullioLRPRulesExt = "Tullio"
+
 [compat]
 ColorSchemes = "3.18"
 Distributions = "0.25"
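The [weakdeps] and [extensions] entries above tell Julia 1.9+ to load the TullioLRPRulesExt module whenever Tullio is available alongside ExplainableAI. A minimal sketch of how one might confirm the extension is active from the REPL (hypothetical usage, not part of the commit, assuming Julia ≥ 1.9 where Base.get_extension exists):

    using ExplainableAI
    using Tullio   # with Tullio loaded, the TullioLRPRulesExt extension is active
    Base.get_extension(ExplainableAI, :TullioLRPRulesExt)  # returns the extension module, not `nothing`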

ext/TullioLRPRulesExt.jl

Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
+module TullioLRPRulesExt
+
+using ExplainableAI, Flux, Tullio
+import ExplainableAI: lrp!, modify_input, modify_denominator
+import ExplainableAI: ZeroRule, EpsilonRule, GammaRule, WSquareRule
+
+# Fast implementation for Dense layer using Tullio.jl's einsum notation:
+for R in (ZeroRule, EpsilonRule, GammaRule)
+    @eval function lrp!(Rᵏ, rule::$R, layer::Dense, modified_layer, aᵏ, Rᵏ⁺¹)
+        layer = isnothing(modified_layer) ? layer : modified_layer
+        ãᵏ = modify_input(rule, aᵏ)
+        z = modify_denominator(rule, layer(ãᵏ))
+        @tullio Rᵏ[j, b] = layer.weight[i, j] * ãᵏ[j, b] / z[i, b] * Rᵏ⁺¹[i, b]
+    end
+end
+
+function lrp!(Rᵏ, ::WSquareRule, _layer::Dense, modified_layer::Dense, aᵏ, Rᵏ⁺¹)
+    den = sum(modified_layer.weight; dims=2)
+    @tullio Rᵏ[j, b] = modified_layer.weight[i, j] / den[i] * Rᵏ⁺¹[i, b]
+end
+end # module
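For readers unfamiliar with Tullio's einsum notation: in the @tullio expression the index i appears only on the right-hand side, so it is summed over. A rough, unoptimized loop equivalent of the ZeroRule/EpsilonRule/GammaRule rule above (illustrative sketch only; lrp_dense_loop! is a hypothetical name, the actual implementation is the @tullio one-liner in the extension):

    # Illustrative sketch: the explicit loops behind the @tullio expression above.
    # W is layer.weight of size (out, in), z = layer(ãᵏ), Rᵏ has size (in, batch).
    function lrp_dense_loop!(Rᵏ, W, ãᵏ, z, Rᵏ⁺¹)
        fill!(Rᵏ, 0)
        for b in axes(Rᵏ, 2), j in axes(Rᵏ, 1), i in axes(Rᵏ⁺¹, 1)
            Rᵏ[j, b] += W[i, j] * ãᵏ[j, b] / z[i, b] * Rᵏ⁺¹[i, b]  # sum over the output index i
        end
        return Rᵏ
    end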

src/ExplainableAI.jl

Lines changed: 8 additions & 1 deletion
@@ -6,7 +6,6 @@ using Distributions: Distribution, Sampleable, Normal
 using Random: AbstractRNG, GLOBAL_RNG
 using Flux
 using Zygote
-using Tullio
 using Markdown
 
 # Heatmapping:

@@ -68,4 +67,12 @@ export heatmap
 # utils
 export strip_softmax, flatten_model, canonize
 export preprocess_imagenet
+
+# Package extension backwards compatibility with Julia 1.6.
+# For Julia 1.6, Tullio is treated as a normal dependency and always loaded.
+# https://pkgdocs.julialang.org/v1/creating-packages/#Transition-from-normal-dependency-to-extension
+if !isdefined(Base, :get_extension)
+    include("../ext/TullioLRPRulesExt.jl")
+end
+
 end # module
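The !isdefined(Base, :get_extension) guard follows the transition recipe linked above: on Julia 1.6 the extension file is simply include()d, so TullioLRPRulesExt becomes an ordinary submodule and the Tullio-based methods are always defined; on 1.9+ the branch is skipped and loading is left to the extension mechanism. A hedged sketch (not part of the commit) of how other code could obtain the extension module on either Julia version:

    # Hypothetical helper: get a handle to TullioLRPRulesExt regardless of Julia version.
    ext = if isdefined(Base, :get_extension)
        Base.get_extension(ExplainableAI, :TullioLRPRulesExt)  # Julia ≥ 1.9: `nothing` until Tullio is loaded
    else
        ExplainableAI.TullioLRPRulesExt  # Julia 1.6: include()d above, so always present
    end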

src/lrp/rules.jl

Lines changed: 2 additions & 14 deletions
@@ -540,17 +540,5 @@ function lrp!(Rᵏ, _rule::FlatRule, _layer::Dense, _modified_layer, _aᵏ, Rᵏ⁺¹)
 end
 end
 
-# Fast implementation for Dense layer using Tullio.jl's einsum notation:
-for R in (ZeroRule, EpsilonRule, GammaRule)
-    @eval function lrp!(Rᵏ, rule::$R, layer::Dense, modified_layer, aᵏ, Rᵏ⁺¹)
-        layer = isnothing(modified_layer) ? layer : modified_layer
-        ãᵏ = modify_input(rule, aᵏ)
-        z = modify_denominator(rule, layer(ãᵏ))
-        @tullio Rᵏ[j, b] = layer.weight[i, j] * ãᵏ[j, b] / z[i, b] * Rᵏ⁺¹[i, b]
-    end
-end
-
-function lrp!(Rᵏ, ::WSquareRule, _layer::Dense, modified_layer::Dense, aᵏ, Rᵏ⁺¹)
-    den = sum(modified_layer.weight; dims=2)
-    @tullio Rᵏ[j, b] = modified_layer.weight[i, j] / den[i] * Rᵏ⁺¹[i, b]
-end
+# Fast implementations for Dense layers can be conditionally loaded with Tullio.jl
+# using package extensions, see ext/TullioLRPRulesExt.jl

test/Project.toml

Lines changed: 4 additions & 1 deletion
@@ -11,13 +11,16 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 ReferenceTests = "324d217c-45ce-50fc-942e-d289b448e8cf"
 Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
+Tullio = "bc48ee85-29a4-5162-ae0b-a64e1601d4bc"
 
 [compat]
+Aqua = "0.7"
 Distributions = "0.25"
 Flux = "0.13, 0.14"
 ImageCore = "0.9, 0.10"
 JLD2 = "0.4"
 LoopVectorization = "0.12"
 Metalhead = "0.8"
-ReferenceTests = "0.10.1"
+ReferenceTests = "0.10"
 Suppressor = "0.2"
+Tullio = "0.3"

test/runtests.jl

Lines changed: 5 additions & 2 deletions
@@ -9,10 +9,13 @@ using Random
 pseudorand(dims...) = rand(MersenneTwister(123), Float32, dims...)
 
 @testset "ExplainableAI.jl" begin
-    # Run Aqua.jl quality assurance tests
     @testset "Aqua.jl" begin
         @info "Running Aqua.jl's auto quality assurance tests. These might print warnings from dependencies."
-        Aqua.test_all(ExplainableAI; ambiguities=false)
+        # Package extensions break Project.toml formatting tests on Julia 1.6
+        # https://github.com/JuliaTesting/Aqua.jl/issues/105
+        Aqua.test_all(
+            ExplainableAI; ambiguities=false, project_toml_formatting=VERSION >= v"1.7"
+        )
    end
 
     # Run package tests
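The VERSION gate keeps Aqua's Project.toml formatting check enabled on Julia 1.7+ but disables it on 1.6, where the [weakdeps]/[extensions] sections trip the check (see the linked Aqua.jl issue). An equivalent, slightly more explicit spelling of the same call (hypothetical rewrite, identical behavior):

    # Same effect as the call in the diff, with the version gate named explicitly:
    check_project_toml = VERSION >= v"1.7"  # extension sections break the check on Julia 1.6
    Aqua.test_all(ExplainableAI; ambiguities=false, project_toml_formatting=check_project_toml)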

test/test_rules.jl

Lines changed: 15 additions & 0 deletions
@@ -238,6 +238,21 @@ layers = Dict(
 end
 end
 
+# Test with loaded package extension
+using LoopVectorization
+using Tullio
+@testset "Dense Tullio" begin
+    for (rulename, rule) in RULES
+        @testset "$rulename" begin
+            for (layername, layer) in layers
+                @testset "$layername" begin
+                    run_rule_tests(rule, layer, rulename, layername, aᵏ_dense)
+                end
+            end
+        end
+    end
+end
+
 ## Test ConvLayers and others
 cin, cout = 3, 4
 insize = (6, 6, 3, batchsize)
0 commit comments
