Skip to content

Commit ff1f850

Browse files
authored
Add ExplicitImports.jl tests (#21)
1 parent c19c823 commit ff1f850

File tree

8 files changed

+78
-33
lines changed

8 files changed

+78
-33
lines changed

Project.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,10 @@ version = "3.0.0-DEV"
55

66
[deps]
77
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
8+
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
89
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
910
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
11+
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
1012
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1113
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
1214
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
@@ -15,8 +17,10 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
1517

1618
[compat]
1719
Flux = "0.14"
20+
MLUtils = "0.4.4"
1821
MacroTools = "0.5"
1922
Markdown = "1"
23+
NNlib = "0.9.24"
2024
Random = "1"
2125
Reexport = "1"
2226
Statistics = "1"

docs/src/developer.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ Using the [Zygote.jl](https://github.com/FluxML/Zygote.jl) AD system,
100100
we obtain the output $z$ of a modified layer and its pullback `back` in a single function call:
101101

102102
```julia
103-
z, back = Zygote.pullback(modified_layer, aᵏ)
103+
z, back = pullback(modified_layer, aᵏ)
104104
```
105105
We then call the pullback with the vector $s$ to obtain $c$:
106106
```julia
@@ -201,7 +201,7 @@ function lrp!(Rᵏ, rule, layer, modified_layer, aᵏ, Rᵏ⁺¹)
201201
layer = isnothing(modified_layer) ? layer : modified_layer
202202

203203
ãᵏ = modify_input(rule, aᵏ)
204-
z, back = Zygote.pullback(modified_layer, ãᵏ)
204+
z, back = pullback(modified_layer, ãᵏ)
205205
s = Rᵏ⁺¹ ./ modify_denominator(rule, z)
206206
Rᵏ .= ãᵏ .* only(back(s))
207207
end

src/RelevancePropagation.jl

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,27 @@
11
module RelevancePropagation
22

3-
using Reexport
4-
@reexport using XAIBase
3+
using Base.Iterators
4+
using Reexport: @reexport
55
import XAIBase: call_analyzer
6+
using XAIBase: XAIBase, AbstractXAIMethod, Explanation
7+
using XAIBase: AbstractOutputSelector, AbstractFeatureSelector, number_of_features
68

7-
using XAIBase: AbstractFeatureSelector, number_of_features
8-
using Base.Iterators
99
using MacroTools: @forward
10-
using Flux
11-
using Flux: Scale, normalise
12-
using Zygote
13-
using Markdown
10+
using Flux: Flux, Chain, Parallel, SkipConnection
11+
using Flux: Dense, Conv, ConvTranspose, CrossCor
12+
using Flux: BatchNorm, GroupNorm, InstanceNorm, LayerNorm, Scale
13+
using Flux:
14+
MaxPool, MeanPool, AdaptiveMaxPool, AdaptiveMeanPool, GlobalMaxPool, GlobalMeanPool
15+
using Flux: AlphaDropout, Dropout, dropout
16+
using NNlib: relu, gelu, swish, mish, softmax, softmax!
17+
using MLUtils: MLUtils
18+
19+
using Zygote: pullback
20+
using Markdown: @md_str
1421
using Statistics: mean, std
1522

23+
@reexport using XAIBase
24+
1625
include("bibliography.jl")
1726
include("layer_types.jl")
1827
include("layer_utils.jl")

src/layer_types.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,10 @@ const DataflowLayer = Union{Chain,Parallel,SkipConnection}
55
const ConvLayer = Union{Conv,ConvTranspose,CrossCor}
66

77
"""Union type for dropout layers."""
8-
const DropoutLayer = Union{Dropout,typeof(Flux.dropout),AlphaDropout}
8+
const DropoutLayer = Union{Dropout,typeof(dropout),AlphaDropout}
99

1010
"""Union type for reshaping layers such as `flatten`."""
11-
const ReshapingLayer = Union{typeof(Flux.flatten),typeof(Flux.MLUtils.flatten)}
11+
const ReshapingLayer = Union{typeof(Flux.flatten),typeof(MLUtils.flatten)}
1212

1313
"""Union type for max pooling layers."""
1414
const MaxPoolLayer = Union{MaxPool,AdaptiveMaxPool,GlobalMaxPool}

src/rules.jl

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ const LRP_DEFAULT_BETA = 1.0f0
1212
function lrp!(Rᵏ, rule::AbstractLRPRule, layer, modified_layer, aᵏ, Rᵏ⁺¹)
1313
layer = isnothing(modified_layer) ? layer : modified_layer
1414
ãᵏ = modify_input(rule, aᵏ)
15-
z, back = Zygote.pullback(layer, ãᵏ)
15+
z, back = pullback(layer, ãᵏ)
1616
s = Rᵏ⁺¹ ./ modify_denominator(rule, z)
1717
c = only(back(s))
1818
Rᵏ .= ãᵏ .* c
@@ -338,9 +338,9 @@ function lrp!(Rᵏ, rule::ZBoxRule, layer, modified_layers, aᵏ, Rᵏ⁺¹)
338338
l = zbox_input(aᵏ, rule.low)
339339
h = zbox_input(aᵏ, rule.high)
340340

341-
z, back = Zygote.pullback(layer, aᵏ)
342-
z⁺, back⁺ = Zygote.pullback(modified_layers.layer⁺, l)
343-
z⁻, back⁻ = Zygote.pullback(modified_layers.layer⁻, h)
341+
z, back = pullback(layer, aᵏ)
342+
z⁺, back⁺ = pullback(modified_layers.layer⁺, l)
343+
z⁻, back⁻ = pullback(modified_layers.layer⁻, h)
344344

345345
s = Rᵏ⁺¹ ./ modify_denominator(rule, z - z⁺ - z⁻)
346346
c = only(back(s))
@@ -402,8 +402,8 @@ function lrp!(Rᵏ, rule::AlphaBetaRule, _layer, modified_layers, aᵏ, Rᵏ⁺
402402
aᵏ⁺ = keep_positive(aᵏ)
403403
aᵏ⁻ = keep_negative(aᵏ)
404404

405-
zᵅ⁺, back⁺ = Zygote.pullback(modified_layers.layerᵅ⁺, aᵏ⁺)
406-
zᵅ⁻, back⁻ = Zygote.pullback(modified_layers.layerᵅ⁻, aᵏ⁻)
405+
zᵅ⁺, back⁺ = pullback(modified_layers.layerᵅ⁺, aᵏ⁺)
406+
zᵅ⁻, back⁻ = pullback(modified_layers.layerᵅ⁻, aᵏ⁻)
407407
# No need to linearize again: Wᵝ⁺ = Wᵅ⁺ and Wᵝ⁻ = Wᵅ⁻
408408
zᵝ⁺ = modified_layers.layerᵝ⁺(aᵏ⁻)
409409
zᵝ⁻ = modified_layers.layerᵝ⁻(aᵏ⁺)
@@ -451,8 +451,8 @@ function lrp!(Rᵏ, rule::ZPlusRule, _layer, modified_layers, aᵏ, Rᵏ⁺¹)
451451
aᵏ⁺ = keep_positive(aᵏ)
452452
aᵏ⁻ = keep_negative(aᵏ)
453453

454-
z⁺, back⁺ = Zygote.pullback(modified_layers.layer⁺, aᵏ⁺)
455-
z⁻, back⁻ = Zygote.pullback(modified_layers.layer⁻, aᵏ⁻)
454+
z⁺, back⁺ = pullback(modified_layers.layer⁺, aᵏ⁺)
455+
z⁻, back⁻ = pullback(modified_layers.layer⁻, aᵏ⁻)
456456

457457
s = Rᵏ⁺¹ ./ modify_denominator(rule, z⁺ + z⁻)
458458
c⁺ = only(back⁺(s))
@@ -504,8 +504,8 @@ function lrp!(Rᵏ, rule::GeneralizedGammaRule, layer, modified_layers, aᵏ, R
504504
aᵏ⁺ = keep_positive(aᵏ)
505505
aᵏ⁻ = keep_negative(aᵏ)
506506

507-
zˡ⁺, back⁺ = Zygote.pullback(modified_layers.layerˡ⁺, aᵏ⁺)
508-
zˡ⁻, back⁻ = Zygote.pullback(modified_layers.layerˡ⁻, aᵏ⁻)
507+
zˡ⁺, back⁺ = pullback(modified_layers.layerˡ⁺, aᵏ⁺)
508+
zˡ⁻, back⁻ = pullback(modified_layers.layerˡ⁻, aᵏ⁻)
509509
# No need to linearize again: Wˡ⁺ = Wʳ⁺ and Wˡ⁻ = Wʳ⁻
510510
zʳ⁺ = modified_layers.layerʳ⁺(aᵏ⁻)
511511
zʳ⁻ = modified_layers.layerʳ⁻(aᵏ⁺)

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
[deps]
22
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
33
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
4+
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
45
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
56
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
67
JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"

test/runtests.jl

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,18 +6,9 @@ using Aqua
66

77
@testset verbose = true "RelevancePropagation.jl" begin
88
@testset verbose = true "Linting" begin
9-
@testset "Code formatting" begin
10-
@info "- Testing code formatting with JuliaFormatter..."
11-
@test JuliaFormatter.format(
12-
RelevancePropagation; verbose=false, overwrite=false
13-
)
14-
end
15-
@testset "Aqua.jl" begin
16-
@info "- Running Aqua.jl tests. These might print warnings from dependencies..."
17-
Aqua.test_all(RelevancePropagation; ambiguities=false)
18-
end
9+
@info "Testing linting..."
10+
include("test_linting.jl")
1911
end
20-
2112
@testset "Utilities" begin
2213
@info "Testing utilities..."
2314
include("test_utils.jl")

test/test_linting.jl

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
using RelevancePropagation
2+
using Test
3+
4+
using JuliaFormatter: JuliaFormatter
5+
using Aqua: Aqua
6+
using ExplicitImports:
7+
check_no_implicit_imports,
8+
check_no_stale_explicit_imports,
9+
check_all_explicit_imports_via_owners,
10+
check_all_qualified_accesses_via_owners,
11+
check_no_self_qualified_accesses
12+
13+
@testset "Code formatting" begin
14+
@info "...with JuliaFormatter.jl"
15+
@test JuliaFormatter.format(RelevancePropagation; verbose=false, overwrite=false)
16+
end
17+
18+
@testset "Aqua.jl" begin
19+
@info "...with Aqua.jl"
20+
Aqua.test_all(RelevancePropagation; ambiguities=false)
21+
end
22+
23+
@testset "ExplicitImports tests" begin
24+
@info "...with ExplicitImports.jl"
25+
@testset "Improper implicit imports" begin
26+
@test check_no_implicit_imports(RelevancePropagation) === nothing
27+
end
28+
@testset "Improper explicit imports" begin
29+
@test check_no_stale_explicit_imports(RelevancePropagation;) === nothing
30+
@test check_all_explicit_imports_via_owners(RelevancePropagation) === nothing
31+
# TODO: test in the future when `public` is more common
32+
# @test check_all_explicit_imports_are_public(RelevancePropagation) === nothing
33+
end
34+
@testset "Improper qualified accesses" begin
35+
@test check_all_qualified_accesses_via_owners(RelevancePropagation) === nothing
36+
@test check_no_self_qualified_accesses(RelevancePropagation) === nothing
37+
# TODO: test in the future when `public` is more common
38+
# @test check_all_qualified_accesses_are_public(RelevancePropagation) === nothing
39+
end
40+
end

0 commit comments

Comments
 (0)