Skip to content

Commit 62c908d

Browse files
committed
Fix benchmarks
1 parent e7c0d2b commit 62c908d

File tree

1 file changed

+10
-8
lines changed

1 file changed

+10
-8
lines changed

benchmark/benchmarks.jl

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ on_CI = haskey(ENV, "GITHUB_ACTIONS")
66

77
include("../test/vgg11.jl")
88
vgg11 = VGG11(; pretrain=false)
9-
model = flatten_model(strip_softmax(vgg19.layers))
9+
model = flatten_model(strip_softmax(vgg11.layers))
1010
img = rand(MersenneTwister(123), Float32, (224, 224, 3, 1))
1111

1212
# Benchmark custom LRP composite
@@ -23,11 +23,13 @@ algs = Dict(
2323
)
2424

2525
# Define benchmark
26+
contruct_analyzer(alg, model) = alg(model) # for use with @benchmarkable macro
27+
2628
SUITE = BenchmarkGroup()
2729
SUITE["VGG"] = BenchmarkGroup([k for k in keys(algs)])
2830
for (name, alg) in algs
2931
SUITE["VGG"][name] = BenchmarkGroup(["construct analyzer", "analyze"])
30-
SUITE["VGG"][name]["construct analyzer"] = @benchmarkable alg($(model))
32+
SUITE["VGG"][name]["construct analyzer"] = @benchmarkable contruct_analyzer($(alg), $(model))
3133

3234
analyzer = alg(model)
3335
SUITE["VGG"][name]["analyze"] = @benchmarkable analyze($(img), $(analyzer))
@@ -58,15 +60,15 @@ rules = Dict(
5860
)
5961
rulenames = [k for k in keys(rules)]
6062

63+
test_rule(rule, layer, aₖ, Rₖ₊₁) = rule(layer, aₖ, Rₖ₊₁) # for use with @benchmarkable macro
64+
6165
for (layername, (layer, aₖ)) in layers
6266
SUITE[layername] = BenchmarkGroup(rulenames)
67+
Rₖ₊₁ = layer(aₖ)
6368

64-
for (rulename, ruletype) in rules
65-
Rₖ₊₁ = layer(aₖ)
69+
for (rulename, rule) in rules
6670
SUITE[layername][rulename] = BenchmarkGroup(["dispatch", "AD fallback"])
67-
SUITE[layername][rulename]["dispatch"] = @benchmarkable rule($layer, $aₖ, $Rₖ₊₁)
68-
SUITE[layername][rulename]["AD fallback"] = @benchmarkable rule(
69-
$TestWrapper(layer), $aₖ, $Rₖ₊₁
70-
)
71+
SUITE[layername][rulename]["dispatch"] = @benchmarkable test_rule($(rule), $(layer), $(aₖ), $(Rₖ₊₁))
72+
SUITE[layername][rulename]["AD fallback"] = @benchmarkable test_rule($(rule), $(TestWrapper(layer)), $(aₖ), $(Rₖ₊₁))
7173
end
7274
end

0 commit comments

Comments
 (0)