
Commit a86c962

Modularize package tests (#17)
* Modularize tests
* Fix type inferrability
* Fix Flux GPU URL
1 parent: 195da29


13 files changed: +66, -27 lines


docs/src/literate/basics.jl

Lines changed: 1 addition & 1 deletion
@@ -142,7 +142,7 @@ expl.extras.layerwise_relevances
 # ## Performance tips
 # ### Using LRP with a GPU
 # All LRP analyzers support GPU backends,
-# building on top of [Flux.jl's GPU support](https://fluxml.ai/Flux.jl/stable/gpu/).
+# building on top of [Flux.jl's GPU support](https://fluxml.ai/Flux.jl/stable/guide/gpu/).
 # Using a GPU only requires moving the input array and model weights to the GPU.
 #
 # For example, using [CUDA.jl](https://github.com/JuliaGPU/CUDA.jl):
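
The CUDA.jl snippet that follows this line in the documentation is not part of the diff. As a rough, hypothetical sketch of what the changed paragraph describes (model and sizes are invented here; assumes a CUDA-capable GPU):

using CUDA, Flux
using RelevancePropagation

model = Chain(Dense(100 => 50, relu), Dense(50 => 10)) |> gpu  # move model weights to the GPU
input = rand(Float32, 100, 1) |> gpu                           # move the input array to the GPU

analyzer = LRP(model)
expl = analyze(input, analyzer)  # relevances are then computed on the GPU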

src/rules.jl

Lines changed: 1 addition & 0 deletions
@@ -16,6 +16,7 @@ function lrp!(Rᵏ, rule::AbstractLRPRule, layer, modified_layer, aᵏ, Rᵏ⁺
     s = Rᵏ⁺¹ ./ modify_denominator(rule, z)
     c = only(back(s))
     Rᵏ .= ãᵏ .* c
+    return Rᵏ
 end
 
 #===================================#
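
The added `return Rᵏ` ties in with the commit's "Fix type inferrability" note. A hedged sketch of how return-type stability of `lrp!` could be checked, with a hypothetical Dense layer and the argument layout from the signature above (the unmodified layer is passed in place of `modified_layer`, which should be acceptable for `ZeroRule`):

using Test, Flux
using RelevancePropagation
using RelevancePropagation: lrp!

layer = Dense(4 => 3, relu)
aᵏ    = rand(Float32, 4, 1)   # activations entering the layer
Rᵏ    = similar(aᵏ)           # preallocated relevance buffer to be filled
Rᵏ⁺¹  = rand(Float32, 3, 1)   # relevance coming from the layer above

@inferred lrp!(Rᵏ, ZeroRule(), layer, layer, aᵏ, Rᵏ⁺¹)  # fails if the return type cannot be inferred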

test/Project.toml

Lines changed: 0 additions & 1 deletion
@@ -11,4 +11,3 @@ ReferenceTests = "324d217c-45ce-50fc-942e-d289b448e8cf"
 Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 XAIBase = "9b48221d-a747-4c1b-9860-46a1d8ba24a7"
-

test/runtests.jl

Lines changed: 15 additions & 16 deletions
@@ -1,24 +1,23 @@
-using Test
-using ReferenceTests
-using Aqua
-using JuliaFormatter
-using Random
-
 using RelevancePropagation
-using Flux
-import Flux: Scale
+using Test
+using JuliaFormatter
+using Aqua
 
-pseudorand(dims...) = rand(MersenneTwister(123), Float32, dims...)
-
 @testset "RelevancePropagation.jl" begin
-    @testset "Aqua.jl" begin
-        @info "Running Aqua.jl's auto quality assurance tests. These might print warnings from dependencies."
-        Aqua.test_all(RelevancePropagation; ambiguities=false)
-    end
-    @testset "JuliaFormatter.jl" begin
-        @info "Running JuliaFormatter's code formatting tests."
-        @test format(RelevancePropagation; verbose=false, overwrite=false)
+    if VERSION >= v"1.10"
+        @testset "Code formatting" begin
+            @info "- Testing code formatting with JuliaFormatter..."
+            @test JuliaFormatter.format(
+                RelevancePropagation; verbose=false, overwrite=false
+            )
+        end
+        @testset "Aqua.jl" begin
+            @info "- Running Aqua.jl tests. These might print warnings from dependencies..."
+            Aqua.test_all(RelevancePropagation; ambiguities=false)
+        end
     end
+
     @testset "Utilities" begin
         @info "Testing utilities..."
         include("test_utils.jl")
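
Because each test file now loads its own dependencies and defines its own helpers, a single file can be run in isolation. One hypothetical way to do this, assuming the helper package TestEnv.jl is installed and Julia is started from the package root:

using TestEnv
TestEnv.activate()               # activate the test environment from test/Project.toml
include("test/test_batches.jl")  # run one self-contained test file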

test/test_batches.jl

Lines changed: 6 additions & 2 deletions
@@ -1,6 +1,10 @@
-using Flux
 using RelevancePropagation
-using Random
+using Test
+
+using Flux
+using Random: rand, MersenneTwister
+
+pseudorand(dims...) = rand(MersenneTwister(123), Float32, dims...)
 
 ## Test `fuse_batchnorm` on Dense and Conv layers
 ins = 20
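
The `pseudorand` helper, now duplicated into each test file that needs it, reseeds a fresh `MersenneTwister` on every call, so identical dimensions always produce identical arrays and the reference tests stay deterministic:

using Random: rand, MersenneTwister

pseudorand(dims...) = rand(MersenneTwister(123), Float32, dims...)
pseudorand(2, 3) == pseudorand(2, 3)  # true — the RNG is reseeded on every call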

test/test_canonize.jl

Lines changed: 5 additions & 1 deletion
@@ -1,9 +1,13 @@
+using RelevancePropagation
+using Test
+
 using Flux
 using Flux: flatten, Scale
-using RelevancePropagation
 using RelevancePropagation: canonize_fuse
 using Random
 
+pseudorand(dims...) = rand(MersenneTwister(123), Float32, dims...)
+
 batchsize = 50
 
 ##=====================================#
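
A hedged sketch of the canonization being tested here, with a hypothetical toy model and assuming the exported `canonize` function fuses `BatchNorm` layers into the preceding `Dense`/`Conv` layer:

using Flux
using RelevancePropagation

model = Chain(Dense(4 => 4), BatchNorm(4), Dense(4 => 2))
model_canonized = canonize(model)  # BatchNorm parameters folded into the first Dense layer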

test/test_chain_utils.jl

Lines changed: 3 additions & 0 deletions
@@ -1,3 +1,6 @@
+using RelevancePropagation
+using Test
+
 using RelevancePropagation: ChainTuple, ParallelTuple, SkipConnectionTuple
 using RelevancePropagation: ModelIndex, chainmap, chainindices, chainzip
 using RelevancePropagation: activation_fn

test/test_checks.jl

Lines changed: 4 additions & 0 deletions
@@ -1,6 +1,10 @@
 using RelevancePropagation
+using Test
+using ReferenceTests
+
 using RelevancePropagation: check_lrp_compat, print_lrp_model_check
 using Suppressor
+
 err = ErrorException("Unknown layer or activation function found in model")
 
 # Flux layers
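
The `err` defined above is the error the LRP model check is expected to raise for unsupported models. A sketch of that behaviour, using a hypothetical model containing an unsupported anonymous-function layer:

using Test, Flux
using RelevancePropagation

unsupported = Chain(Dense(2 => 2, relu), x -> x .^ 2)  # anonymous functions fail the model check
@test_throws ErrorException LRP(unsupported)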

test/test_cnn.jl

Lines changed: 6 additions & 0 deletions
@@ -1,13 +1,19 @@
 using RelevancePropagation
+using Test
+using ReferenceTests
+
 using Flux
 using JLD2
+using Random: rand, MersenneTwister
 
 const LRP_ANALYZERS = Dict(
     "LRPZero" => LRP,
     "LRPZero_COC" => m -> LRP(m; flatten=false), # chain of chains
     "LRPEpsilonAlpha2Beta1Flat" => m -> LRP(m, EpsilonAlpha2Beta1Flat()),
 )
 
+pseudorand(dims...) = rand(MersenneTwister(123), Float32, dims...)
+
 input_size = (32, 32, 3, 1)
 input = pseudorand(input_size)
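
A sketch of how the `LRP_ANALYZERS` constructors above might be exercised, substituting a hypothetical small CNN for the real test model:

using Test, Flux
using RelevancePropagation

model = Chain(Conv((3, 3), 3 => 8, relu), Flux.flatten, Dense(7200 => 10))
input = rand(Float32, 32, 32, 3, 1)

for (name, make_analyzer) in LRP_ANALYZERS
    analyzer = make_analyzer(model)
    expl = analyze(input, analyzer)
    @test size(expl.val) == size(input)  # the relevance map matches the input shape
end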

test/test_composite.jl

Lines changed: 6 additions & 1 deletion
@@ -1,6 +1,11 @@
 using RelevancePropagation
+using Test
+using ReferenceTests
+
+using NNlib
+using Flux
+using Flux: flatten, Scale
 using Metalhead
-using Flux, NNlib
 
 model = VGG(11; pretrain=false).layers
 model_flat = flatten_model(model)
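
As a quick illustration of `flatten_model`, a sketch with a hypothetical nested chain, assuming it simply unnests inner `Chain`s as its use on the VGG layers above suggests:

using Flux
using RelevancePropagation

nested = Chain(Chain(Dense(4 => 4, relu), Dense(4 => 4, relu)), Dense(4 => 2))
flat = flatten_model(nested)
length(flat)  # 3 — the inner Chain has been unnested into the outer one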
