Commit 12046e5

Update benchmarks and test them with PkgJogger (#175)
1 parent cf65291 commit 12046e5

7 files changed, with 63 additions and 96 deletions.

benchmark/Project.toml

Lines changed: 1 addition & 2 deletions
@@ -2,9 +2,8 @@
 BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
 ExplainableAI = "4f1bc3e1-d60d-4ed0-9367-9bdff9846d3b"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
-LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
 PkgBenchmark = "32113eaa-f34f-5b0d-bd6c-c81e245fc73d"
-Tullio = "bc48ee85-29a4-5162-ae0b-a64e1601d4bc"
+PkgJogger = "10150987-6cc1-4b76-abee-b1c1cbd91c01"
 
 [compat]
 BenchmarkTools = "1"

benchmark/bench_jogger.jl

Lines changed: 47 additions & 0 deletions
@@ -0,0 +1,47 @@
+using BenchmarkTools
+using Flux
+using ExplainableAI
+
+on_CI = haskey(ENV, "GITHUB_ACTIONS")
+
+T = Float32
+input_size = (32, 32, 3, 1)
+input = rand(T, input_size)
+
+model = Chain(
+    Chain(
+        Conv((3, 3), 3 => 8, relu; pad=1),
+        Conv((3, 3), 8 => 8, relu; pad=1),
+        MaxPool((2, 2)),
+        Conv((3, 3), 8 => 16, relu; pad=1),
+        Conv((3, 3), 16 => 16, relu; pad=1),
+        MaxPool((2, 2)),
+    ),
+    Chain(
+        Flux.flatten,
+        Dense(1024 => 512, relu), # 102_764_544 parameters
+        Dropout(0.5),
+        Dense(512 => 100, relu),
+    ),
+)
+Flux.testmode!(model, true)
+
+# Use one representative algorithm of each type
+METHODS = Dict(
+    "Gradient" => Gradient,
+    "InputTimesGradient" => InputTimesGradient,
+    "SmoothGrad" => model -> SmoothGrad(model, 5),
+    "IntegratedGradients" => model -> IntegratedGradients(model, 5),
+)
+
+# Define benchmark
+construct(method, model) = method(model) # for use with @benchmarkable macro
+
+suite = BenchmarkGroup()
+suite["CNN"] = BenchmarkGroup([k for k in keys(METHODS)])
+for (name, method) in METHODS
+    analyzer = method(model)
+    suite["CNN"][name] = BenchmarkGroup(["construct analyzer", "analyze"])
+    suite["CNN"][name]["constructor"] = @benchmarkable construct($(method), $(model))
+    suite["CNN"][name]["analyze"] = @benchmarkable analyze($(input), $(analyzer))
+end

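The suite defined in benchmark/bench_jogger.jl is a plain BenchmarkTools.BenchmarkGroup, so it can also be tuned and run on its own. The following is a minimal sketch, not part of this commit, assuming it is executed from the repository root with the benchmark environment active:

    # Illustration only, not part of commit 12046e5.
    using BenchmarkTools

    include("benchmark/bench_jogger.jl")  # defines `suite`, `model`, and `input`

    tune!(suite)                           # estimate per-benchmark evaluation counts
    results = run(suite; verbose=true)

    # Inspect a single leaf, e.g. the analysis pass of the Gradient analyzer:
    display(minimum(results["CNN"]["Gradient"]["analyze"]))
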
benchmark/benchmarks.jl

Lines changed: 4 additions & 94 deletions
@@ -1,95 +1,5 @@
-using BenchmarkTools
-using LoopVectorization
-using Tullio
-using Flux
+using PkgJogger
 using ExplainableAI
-using ExplainableAI: lrp!, modify_layer
-
-on_CI = haskey(ENV, "GITHUB_ACTIONS")
-
-T = Float32
-input_size = (32, 32, 3, 1)
-input = rand(T, input_size)
-
-model = Chain(
-    Chain(
-        Conv((3, 3), 3 => 8, relu; pad=1),
-        Conv((3, 3), 8 => 8, relu; pad=1),
-        MaxPool((2, 2)),
-        Conv((3, 3), 8 => 16, relu; pad=1),
-        Conv((3, 3), 16 => 16, relu; pad=1),
-        MaxPool((2, 2)),
-    ),
-    Chain(
-        Flux.flatten,
-        Dense(1024 => 512, relu), # 102_764_544 parameters
-        Dropout(0.5),
-        Dense(512 => 100, relu),
-    ),
-)
-Flux.testmode!(model, true)
-
-# Use one representative algorithm of each type
-algs = Dict(
-    "Gradient" => Gradient,
-    "InputTimesGradient" => InputTimesGradient,
-    "LRP" => LRP,
-    "LREpsilonPlusFlat" => model -> LRP(model, EpsilonPlusFlat()),
-    "SmoothGrad" => model -> SmoothGrad(model, 5),
-    "IntegratedGradients" => model -> IntegratedGradients(model, 5),
-)
-
-# Define benchmark
-_alg(alg, model) = alg(model) # for use with @benchmarkable macro
-
-SUITE = BenchmarkGroup()
-SUITE["CNN"] = BenchmarkGroup([k for k in keys(algs)])
-for (name, alg) in algs
-    analyzer = alg(model)
-    SUITE["CNN"][name] = BenchmarkGroup(["construct analyzer", "analyze"])
-    SUITE["CNN"][name]["construct analyzer"] = @benchmarkable _alg($(alg), $(model))
-    SUITE["CNN"][name]["analyze"] = @benchmarkable analyze($(input), $(analyzer))
-end
-
-# generate input for conv layers
-insize = (32, 32, 3, 1)
-in_dense = 64
-out_dense = 10
-aᵏ = rand(T, insize)
-
-layers = Dict(
-    "Conv" => (Conv((3, 3), 3 => 2), aᵏ),
-    "Dense" => (Dense(in_dense, out_dense, relu), randn(T, in_dense, 1)),
-)
-rules = Dict(
-    "ZeroRule" => ZeroRule(),
-    "EpsilonRule" => EpsilonRule(),
-    "GammaRule" => GammaRule(),
-    "WSquareRule" => WSquareRule(),
-    "FlatRule" => FlatRule(),
-    "AlphaBetaRule" => AlphaBetaRule(),
-    "ZPlusRule" => ZPlusRule(),
-    "ZBoxRule" => ZBoxRule(zero(T), oneunit(T)),
-)
-
-layernames = String.(keys(layers))
-rulenames = String.(keys(rules))
-
-SUITE["modify layer"] = BenchmarkGroup(rulenames)
-SUITE["apply rule"] = BenchmarkGroup(rulenames)
-for rname in rulenames
-    SUITE["modify layer"][rname] = BenchmarkGroup(layernames)
-    SUITE["apply rule"][rname] = BenchmarkGroup(layernames)
-end
-
-for (lname, (layer, aᵏ)) in layers
-    Rᵏ = similar(aᵏ)
-    Rᵏ⁺¹ = layer(aᵏ)
-    for (rname, rule) in rules
-        modified_layer = modify_layer(rule, layer)
-        SUITE["modify layer"][rname][lname] = @benchmarkable modify_layer($(rule), $(layer))
-        SUITE["apply rule"][rname][lname] = @benchmarkable lrp!(
-            $(Rᵏ), $(rule), $(layer), $(modified_layer), $(aᵏ), $(Rᵏ⁺¹)
-        )
-    end
-end
+# Use PkgJogger.@jog to create the JogExplainableAI module
+@jog ExplainableAI
+SUITE = JogExplainableAI.suite()

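After this change, benchmarks.jl only wires the PkgJogger-generated module into the existing PkgBenchmark entry point: @jog ExplainableAI creates the JogExplainableAI module, whose suite() gathers the benchmark files (here benchmark/bench_jogger.jl, following PkgJogger's bench_*.jl naming convention) into one BenchmarkGroup that is exposed as SUITE. A rough interactive sketch, not part of this commit:

    # Illustration only, not part of commit 12046e5.
    using PkgJogger, ExplainableAI
    import BenchmarkTools

    @jog ExplainableAI                # generates the JogExplainableAI module
    suite = JogExplainableAI.suite()  # the same object that benchmarks.jl binds to SUITE

    @assert suite isa BenchmarkTools.BenchmarkGroup
    @show keys(suite)                 # expected to contain the "CNN" group from bench_jogger.jl
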
test/Project.toml

Lines changed: 2 additions & 0 deletions
@@ -1,11 +1,13 @@
 [deps]
 Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
+BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
 JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
 JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+PkgJogger = "10150987-6cc1-4b76-abee-b1c1cbd91c01"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 ReferenceTests = "324d217c-45ce-50fc-942e-d289b448e8cf"
 StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"

test/runtests.jl

Lines changed: 4 additions & 0 deletions
@@ -33,4 +33,8 @@ using JET
         @info "Testing analyzers on batches..."
         include("test_batches.jl")
     end
+    @testset "Benchmark correctness" begin
+        @info "Testing whether benchmarks are up-to-date..."
+        include("test_benchmarks.jl")
+    end
 end

test/test_batches.jl

Lines changed: 1 addition & 0 deletions
@@ -3,6 +3,7 @@ using Test
 
 using Flux
 using Random
+using StableRNGs: StableRNG
 using Distributions: Laplace
 
 pseudorand(dims...) = rand(StableRNG(123), Float32, dims...)

test/test_benchmarks.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
using PkgJogger
2+
using ExplainableAI
3+
4+
PkgJogger.@test_benchmarks ExplainableAI

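PkgJogger.@test_benchmarks is the supported entry point for checking that the benchmarks still run. Conceptually, the property it verifies is roughly the hand-written sketch below, which runs every leaf of the jogged suite once inside a test set so that stale or erroring benchmarks fail the test suite (a sketch of the idea only, not PkgJogger's actual implementation):

    # Hand-written sketch of the property checked by the benchmark-correctness test;
    # `PkgJogger.@test_benchmarks ExplainableAI` is the real entry point used above.
    using Test
    using BenchmarkTools
    using PkgJogger
    using ExplainableAI

    @jog ExplainableAI
    suite = JogExplainableAI.suite()

    @testset "benchmarks run without erroring" begin
        for (keypath, bench) in BenchmarkTools.leaves(suite)
            @testset "$(join(keypath, " / "))" begin
                result = BenchmarkTools.run(bench; samples=1, evals=1)
                @test result isa BenchmarkTools.Trial
            end
        end
    end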