1
1
using Flux
2
- using ExplainableAI: flatten_model, has_output_softmax, check_output_softmax
2
+ using ExplainableAI: flatten_model, has_output_softmax, check_output_softmax, activation
3
3
using ExplainableAI: stabilize_denom, batch_dim_view, drop_batch_index
4
+ using Random
5
+
6
+ pseudorand (dims... ) = rand (MersenneTwister (123 ), Float32, dims... )
7
+
8
+ # Test `activation`
9
+ @test activation (Dense (5 , 2 , gelu)) == gelu
10
+ @test activation (Conv ((5 , 5 ), 3 => 2 , softplus)) == softplus
11
+ @test activation (BatchNorm (5 , selu)) == selu
12
+ @test isnothing (activation (flatten))
4
13
5
14
# flatten_model
6
15
@test flatten_model (Chain (Chain (Chain (abs)), sqrt, Chain (relu))) == Chain (abs, sqrt, relu)
@@ -12,14 +21,31 @@ using ExplainableAI: stabilize_denom, batch_dim_view, drop_batch_index
12
21
@test has_output_softmax (Chain (abs, sqrt, relu, tanh)) == false
13
22
@test has_output_softmax (Chain (Chain (abs), sqrt, Chain (Chain (softmax)))) == true
14
23
@test has_output_softmax (Chain (Chain (abs), Chain (Chain (softmax)), sqrt)) == false
24
+ @test has_output_softmax (Chain (Dense (5 , 5 , softmax), Dense (5 , 5 , softmax))) == true
25
+ @test has_output_softmax (Chain (Dense (5 , 5 , softmax), Dense (5 , 5 , relu))) == false
26
+ @test has_output_softmax (Chain (Dense (5 , 5 , softmax), Chain (Dense (5 , 5 , softmax)))) == true
27
+ @test has_output_softmax (Chain (Dense (5 , 5 , softmax), Chain (Dense (5 , 5 , relu)))) == false
15
28
16
29
# check_output_softmax
17
30
@test_throws ArgumentError check_output_softmax (Chain (abs, sqrt, relu, softmax))
18
31
19
32
# strip_softmax
20
- @test strip_softmax (Chain (Chain (abs), sqrt, Chain (Chain (softmax)))) == Chain (abs, sqrt) # flatten to remove softmax
33
+ d_softmax = Dense (2 , 2 , softmax; init= pseudorand)
34
+ d_softmax2 = Dense (2 , 2 , softmax; init= pseudorand)
35
+ d_relu = Dense (2 , 2 , relu; init= pseudorand)
36
+ d_identity = Dense (2 , 2 ; init= pseudorand)
37
+ # flatten to remove softmax
38
+ m = strip_softmax (Chain (Chain (abs), sqrt, Chain (Chain (softmax))))
39
+ @test m == Chain (abs, sqrt)
40
+ m1 = strip_softmax (Chain (d_relu, Chain (d_softmax)))
41
+ m2 = Chain (d_relu, d_identity)
42
+ x = rand (Float32, 2 , 10 )
43
+ @test typeof (m1) == typeof (m2)
44
+ @test m1 (x) == m2 (x)
45
+ # don't do anything if there is no softmax at the end
21
46
@test strip_softmax (Chain (Chain (abs), Chain (Chain (softmax)), sqrt)) ==
22
- Chain (Chain (abs), Chain (Chain (softmax)), sqrt) # don't do anything if there is no softmax at the end
47
+ Chain (Chain (abs), Chain (Chain (softmax)), sqrt)
48
+ @test strip_softmax (Chain (d_softmax, Chain (d_relu))) == Chain (d_softmax, Chain (d_relu))
23
49
24
50
# stabilize_denom
25
51
A = [1.0 0.0 1.0e-25 ; - 1.0 - 0.0 - 1.0e-25 ]
0 commit comments