-using BenchmarkTools
-using Flux
+using PkgJogger
 using RelevancePropagation
-using RelevancePropagation: lrp!, modify_layer
-
-on_CI = haskey(ENV, "GITHUB_ACTIONS")
-
-T = Float32
-input_size = (32, 32, 3, 1)
-input = rand(T, input_size)
-
-model = Chain(
-    Chain(
-        Conv((3, 3), 3 => 8, relu; pad=1),
-        Conv((3, 3), 8 => 8, relu; pad=1),
-        MaxPool((2, 2)),
-        Conv((3, 3), 8 => 16, relu; pad=1),
-        Conv((3, 3), 16 => 16, relu; pad=1),
-        MaxPool((2, 2)),
-    ),
-    Chain(
-        Flux.flatten,
-        Dense(1024 => 512, relu), # 102_764_544 parameters
-        Dropout(0.5),
-        Dense(512 => 100, relu),
-    ),
-)
-Flux.testmode!(model, true)
-
-# Use one representative algorithm of each type
-algs = Dict("LRP" => LRP, "LREpsilonPlusFlat" => model -> LRP(model, EpsilonPlusFlat()))
-
-# Define benchmark
-_alg(alg, model) = alg(model) # for use with @benchmarkable macro
-
-SUITE = BenchmarkGroup()
-SUITE["CNN"] = BenchmarkGroup([k for k in keys(algs)])
-for (name, alg) in algs
-    analyzer = alg(model)
-    SUITE["CNN"][name] = BenchmarkGroup(["construct analyzer", "analyze"])
-    SUITE["CNN"][name]["construct analyzer"] = @benchmarkable _alg($(alg), $(model))
-    SUITE["CNN"][name]["analyze"] = @benchmarkable analyze($(input), $(analyzer))
-end
-
-# generate input for conv layers
-insize = (32, 32, 3, 1)
-in_dense = 64
-out_dense = 10
-aᵏ = rand(T, insize)
-
-layers = Dict(
-    "Conv" => (Conv((3, 3), 3 => 2), aᵏ),
-    "Dense" => (Dense(in_dense, out_dense, relu), randn(T, in_dense, 1)),
-)
-rules = Dict(
-    "ZeroRule" => ZeroRule(),
-    "EpsilonRule" => EpsilonRule(),
-    "GammaRule" => GammaRule(),
-    "WSquareRule" => WSquareRule(),
-    "FlatRule" => FlatRule(),
-    "AlphaBetaRule" => AlphaBetaRule(),
-    "ZPlusRule" => ZPlusRule(),
-    "ZBoxRule" => ZBoxRule(zero(T), oneunit(T)),
-)
-
-layernames = String.(keys(layers))
-rulenames = String.(keys(rules))
-
-SUITE["modify layer"] = BenchmarkGroup(rulenames)
-SUITE["apply rule"] = BenchmarkGroup(rulenames)
-for rname in rulenames
-    SUITE["modify layer"][rname] = BenchmarkGroup(layernames)
-    SUITE["apply rule"][rname] = BenchmarkGroup(layernames)
-end
-
-for (lname, (layer, aᵏ)) in layers
-    Rᵏ = similar(aᵏ)
-    Rᵏ⁺¹ = layer(aᵏ)
-    for (rname, rule) in rules
-        modified_layer = modify_layer(rule, layer)
-        SUITE["modify layer"][rname][lname] = @benchmarkable modify_layer($(rule), $(layer))
-        SUITE["apply rule"][rname][lname] = @benchmarkable lrp!(
-            $(Rᵏ), $(rule), $(layer), $(modified_layer), $(aᵏ), $(Rᵏ⁺¹)
-        )
-    end
-end
+# Use PkgJogger.@jog to create the JogRelevancePropagation module
+@jog RelevancePropagation
+SUITE = JogRelevancePropagation.suite()
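With this change, the file delegates suite construction to the jogger module that `@jog` generates. A minimal usage sketch, assuming PkgJogger's generated-jogger interface; the `benchmark()` call below is illustrative and not part of this commit:

    using PkgJogger
    using RelevancePropagation

    @jog RelevancePropagation                      # generates the JogRelevancePropagation module
    SUITE = JogRelevancePropagation.suite()        # the collected BenchmarkTools suite, as in this commit
    results = JogRelevancePropagation.benchmark()  # assumed jogger API: tune and run the suite

Keeping a top-level `SUITE` constant presumably preserves compatibility with tooling that expects the usual BenchmarkTools convention (e.g. running the group via `run(SUITE)`).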