Commit c19c823

Test benchmarks with PkgJogger (#20)

* Test benchmarks with PkgJogger
* Update benchmark workflow
* Fix benchmark baseline

1 parent e82c566 commit c19c823

File tree

8 files changed: 112 additions & 100 deletions

.github/workflows/Benchmark.yml

Lines changed: 9 additions & 13 deletions

@@ -6,25 +6,21 @@ on:
 jobs:
   Benchmark:
     runs-on: ubuntu-latest
+    permissions:
+      pull-requests: write
+      actions: write # needed to allow julia-actions/cache to proactively delete old caches that it has created
+      contents: read
     if: contains(github.event.pull_request.labels.*.name, 'run benchmark')
     steps:
       - uses: actions/checkout@v4
-      - uses: julia-actions/setup-julia@latest
-      - name: Cache artifacts
-        uses: actions/cache@v3
-        env:
-          cache-name: cache-artifacts
+      - uses: julia-actions/setup-julia@v2
         with:
-          path: ~/.julia/artifacts
-          key: ${{ runner.os }}-test-${{ env.cache-name }}-${{ hashFiles('**/Project.toml') }}
-          restore-keys: |
-            ${{ runner.os }}-test-${{ env.cache-name }}-
-            ${{ runner.os }}-test-
-            ${{ runner.os }}-
+          version: '1'
+      - uses: julia-actions/cache@v2
       - name: Install dependencies
-        run: julia -e 'using Pkg; pkg"add JSON PkgBenchmark [email protected]"'
+        run: julia --color=yes -e 'using Pkg; pkg"add JSON PkgBenchmark [email protected]"'
       - name: Run benchmarks
-        run: julia benchmark/run_benchmarks.jl
+        run: julia --color=yes benchmark/run_benchmarks.jl
         env:
           GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
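
For reference, the two workflow steps above can be reproduced outside CI from the repository root. A minimal sketch in Julia, assuming a recent 1.x release and that the pinned [email protected] is still resolvable:

    # Mirror the "Install dependencies" and "Run benchmarks" workflow steps locally.
    using Pkg
    pkg"add JSON PkgBenchmark [email protected]"   # same pins as the workflow
    include("benchmark/run_benchmarks.jl")        # prints the judgement locally when GITHUB_ACTIONS is unset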

benchmark/Project.toml

Lines changed: 2 additions & 1 deletion

@@ -1,8 +1,9 @@
 [deps]
 BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
-RelevancePropagation = "0be6dd02-ae9e-43eb-b318-c6e81d6890d8"
 PkgBenchmark = "32113eaa-f34f-5b0d-bd6c-c81e245fc73d"
+PkgJogger = "10150987-6cc1-4b76-abee-b1c1cbd91c01"
+RelevancePropagation = "0be6dd02-ae9e-43eb-b318-c6e81d6890d8"
 
 [compat]
 BenchmarkTools = "1"

benchmark/bench_jogger.jl

Lines changed: 86 additions & 0 deletions

@@ -0,0 +1,86 @@
+using BenchmarkTools
+using Flux
+using RelevancePropagation
+using RelevancePropagation: lrp!, modify_layer
+
+on_CI = haskey(ENV, "GITHUB_ACTIONS")
+
+T = Float32
+input_size = (32, 32, 3, 1)
+input = rand(T, input_size)
+
+model = Chain(
+    Chain(
+        Conv((3, 3), 3 => 8, relu; pad=1),
+        Conv((3, 3), 8 => 8, relu; pad=1),
+        MaxPool((2, 2)),
+        Conv((3, 3), 8 => 16, relu; pad=1),
+        Conv((3, 3), 16 => 16, relu; pad=1),
+        MaxPool((2, 2)),
+    ),
+    Chain(
+        Flux.flatten,
+        Dense(1024 => 512, relu), # 102_764_544 parameters
+        Dropout(0.5),
+        Dense(512 => 100, relu),
+    ),
+)
+Flux.testmode!(model, true)
+
+# Use one representative algorithm of each type
+algs = Dict("LRP" => LRP, "LREpsilonPlusFlat" => model -> LRP(model, EpsilonPlusFlat()))
+
+# Define benchmark
+_alg(alg, model) = alg(model) # for use with @benchmarkable macro
+
+suite = BenchmarkGroup()
+suite["CNN"] = BenchmarkGroup([k for k in keys(algs)])
+for (name, alg) in algs
+    analyzer = alg(model)
+    suite["CNN"][name] = BenchmarkGroup(["construct analyzer", "analyze"])
+    suite["CNN"][name]["construct analyzer"] = @benchmarkable _alg($(alg), $(model))
+    suite["CNN"][name]["analyze"] = @benchmarkable analyze($(input), $(analyzer))
+end
+
+# generate input for conv layers
+insize = (32, 32, 3, 1)
+in_dense = 64
+out_dense = 10
+aᵏ = rand(T, insize)
+
+layers = Dict(
+    "Conv" => (Conv((3, 3), 3 => 2), aᵏ),
+    "Dense" => (Dense(in_dense, out_dense, relu), randn(T, in_dense, 1)),
+)
+rules = Dict(
+    "ZeroRule" => ZeroRule(),
+    "EpsilonRule" => EpsilonRule(),
+    "GammaRule" => GammaRule(),
+    "WSquareRule" => WSquareRule(),
+    "FlatRule" => FlatRule(),
+    "AlphaBetaRule" => AlphaBetaRule(),
+    "ZPlusRule" => ZPlusRule(),
+    "ZBoxRule" => ZBoxRule(zero(T), oneunit(T)),
+)
+
+layernames = String.(keys(layers))
+rulenames = String.(keys(rules))
+
+suite["modify layer"] = BenchmarkGroup(rulenames)
+suite["apply rule"] = BenchmarkGroup(rulenames)
+for rname in rulenames
+    suite["modify layer"][rname] = BenchmarkGroup(layernames)
+    suite["apply rule"][rname] = BenchmarkGroup(layernames)
+end
+
+for (lname, (layer, aᵏ)) in layers
+    Rᵏ = similar(aᵏ)
+    Rᵏ⁺¹ = layer(aᵏ)
+    for (rname, rule) in rules
+        modified_layer = modify_layer(rule, layer)
+        suite["modify layer"][rname][lname] = @benchmarkable modify_layer($(rule), $(layer))
+        suite["apply rule"][rname][lname] = @benchmarkable lrp!(
+            $(Rᵏ), $(rule), $(layer), $(modified_layer), $(aᵏ), $(Rᵏ⁺¹)
+        )
+    end
+end
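
PkgJogger's convention (per its documentation) is to pick up every benchmark/bench_*.jl file and to expect each one to define a BenchmarkTools suite, which is why this file binds a lowercase suite instead of the PkgBenchmark SUITE constant. To iterate on just this file without going through CI, a rough sketch using plain BenchmarkTools, assuming the benchmark/ environment is instantiated:

    # Run only the suite defined in this file; `tune!` and `run` come from BenchmarkTools.
    using BenchmarkTools
    include("benchmark/bench_jogger.jl")   # defines `suite`
    tune!(suite)
    results = run(suite; verbose=true)
    show(results)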

benchmark/benchmarks.jl

Lines changed: 4 additions & 85 deletions

@@ -1,86 +1,5 @@
-using BenchmarkTools
-using Flux
+using PkgJogger
 using RelevancePropagation
-using RelevancePropagation: lrp!, modify_layer
-
-on_CI = haskey(ENV, "GITHUB_ACTIONS")
-
-T = Float32
-input_size = (32, 32, 3, 1)
-input = rand(T, input_size)
-
-model = Chain(
-    Chain(
-        Conv((3, 3), 3 => 8, relu; pad=1),
-        Conv((3, 3), 8 => 8, relu; pad=1),
-        MaxPool((2, 2)),
-        Conv((3, 3), 8 => 16, relu; pad=1),
-        Conv((3, 3), 16 => 16, relu; pad=1),
-        MaxPool((2, 2)),
-    ),
-    Chain(
-        Flux.flatten,
-        Dense(1024 => 512, relu), # 102_764_544 parameters
-        Dropout(0.5),
-        Dense(512 => 100, relu),
-    ),
-)
-Flux.testmode!(model, true)
-
-# Use one representative algorithm of each type
-algs = Dict("LRP" => LRP, "LREpsilonPlusFlat" => model -> LRP(model, EpsilonPlusFlat()))
-
-# Define benchmark
-_alg(alg, model) = alg(model) # for use with @benchmarkable macro
-
-SUITE = BenchmarkGroup()
-SUITE["CNN"] = BenchmarkGroup([k for k in keys(algs)])
-for (name, alg) in algs
-    analyzer = alg(model)
-    SUITE["CNN"][name] = BenchmarkGroup(["construct analyzer", "analyze"])
-    SUITE["CNN"][name]["construct analyzer"] = @benchmarkable _alg($(alg), $(model))
-    SUITE["CNN"][name]["analyze"] = @benchmarkable analyze($(input), $(analyzer))
-end
-
-# generate input for conv layers
-insize = (32, 32, 3, 1)
-in_dense = 64
-out_dense = 10
-aᵏ = rand(T, insize)
-
-layers = Dict(
-    "Conv" => (Conv((3, 3), 3 => 2), aᵏ),
-    "Dense" => (Dense(in_dense, out_dense, relu), randn(T, in_dense, 1)),
-)
-rules = Dict(
-    "ZeroRule" => ZeroRule(),
-    "EpsilonRule" => EpsilonRule(),
-    "GammaRule" => GammaRule(),
-    "WSquareRule" => WSquareRule(),
-    "FlatRule" => FlatRule(),
-    "AlphaBetaRule" => AlphaBetaRule(),
-    "ZPlusRule" => ZPlusRule(),
-    "ZBoxRule" => ZBoxRule(zero(T), oneunit(T)),
-)
-
-layernames = String.(keys(layers))
-rulenames = String.(keys(rules))
-
-SUITE["modify layer"] = BenchmarkGroup(rulenames)
-SUITE["apply rule"] = BenchmarkGroup(rulenames)
-for rname in rulenames
-    SUITE["modify layer"][rname] = BenchmarkGroup(layernames)
-    SUITE["apply rule"][rname] = BenchmarkGroup(layernames)
-end
-
-for (lname, (layer, aᵏ)) in layers
-    Rᵏ = similar(aᵏ)
-    Rᵏ⁺¹ = layer(aᵏ)
-    for (rname, rule) in rules
-        modified_layer = modify_layer(rule, layer)
-        SUITE["modify layer"][rname][lname] = @benchmarkable modify_layer($(rule), $(layer))
-        SUITE["apply rule"][rname][lname] = @benchmarkable lrp!(
-            $(Rᵏ), $(rule), $(layer), $(modified_layer), $(aᵏ), $(Rᵏ⁺¹)
-        )
-    end
-end
+# Use PkgJogger.@jog to create the JogRelevancePropagation module
+@jog RelevancePropagation
+SUITE = JogRelevancePropagation.suite()
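
The @jog macro generates a JogRelevancePropagation module whose suite() gathers the suites from all benchmark/bench_*.jl files, so the SUITE constant that PkgBenchmark and BenchmarkCI expect keeps working unchanged. The generated module can also be driven directly; a hedged sketch, with function names taken from the PkgJogger documentation rather than from this commit:

    using PkgJogger
    using RelevancePropagation
    @jog RelevancePropagation                       # defines JogRelevancePropagation
    results = JogRelevancePropagation.benchmark()   # tune and run the full suite
    # See the PkgJogger docs for saving and comparing results between revisions.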

benchmark/run_benchmarks.jl

Lines changed: 1 addition & 1 deletion

@@ -6,5 +6,5 @@
 using BenchmarkCI
 on_CI = haskey(ENV, "GITHUB_ACTIONS")
 
-BenchmarkCI.judge()
+BenchmarkCI.judge(; baseline="origin/main")
 on_CI ? BenchmarkCI.postjudge() : BenchmarkCI.displayjudgement()
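
Passing baseline="origin/main" pins the comparison target to the main branch instead of whatever the default remote ref resolves to on the CI runner; this is the "Fix benchmark baseline" part of the PR. The same judgement can be previewed locally; a small sketch, assuming origin/main exists in the local clone:

    using BenchmarkCI
    BenchmarkCI.judge(; baseline="origin/main")   # benchmark the working tree against origin/main
    BenchmarkCI.displayjudgement()                # print the comparison instead of posting a PR comment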

test/Project.toml

Lines changed: 2 additions & 0 deletions

@@ -1,11 +1,13 @@
 [deps]
 Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
+BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
 JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 Metalhead = "dbeba491-748d-5e0e-a39e-b530a07fa0cc"
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
+PkgJogger = "10150987-6cc1-4b76-abee-b1c1cbd91c01"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 ReferenceTests = "324d217c-45ce-50fc-942e-d289b448e8cf"
 StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"

test/runtests.jl

Lines changed: 4 additions & 0 deletions

@@ -54,4 +54,8 @@ using Aqua
         @info "Testing analyzers on batches..."
         include("test_batches.jl")
     end
+    @testset "Benchmark correctness" begin
+        @info "Testing whether benchmarks are up-to-date..."
+        include("test_benchmarks.jl")
+    end
 end

test/test_benchmarks.jl

Lines changed: 4 additions & 0 deletions

@@ -0,0 +1,4 @@
+using PkgJogger
+using RelevancePropagation
+
+PkgJogger.@test_benchmarks RelevancePropagation
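
According to the PkgJogger documentation, @test_benchmarks wraps each benchmark in the suite in a @testset and runs it once, so a benchmark that references renamed or removed package internals now fails the test suite instead of silently breaking CI benchmarking. The check runs as part of the normal test invocation:

    # Runs test/runtests.jl, including the new "Benchmark correctness" testset.
    using Pkg
    Pkg.test("RelevancePropagation")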
