|
1 |
| -using LuxLib, Reactant, Enzyme |
| 1 | +using LuxLib, Reactant, Enzyme, NNlib |
2 | 2 |
|
3 | 3 | @testset "Fused Dense" begin end
|
4 | 4 |
|
5 | 5 | @testset "Bias Activation" begin end
|
6 | 6 |
|
7 |
| -@testset "Fast Activation" begin end |
| 7 | +@testset "Fast Activation" begin |
| 8 | + # Here we are testing that fast_activation doesn't switch to the faster versions |
| 9 | + sumabs2(f, x) = sum(abs2, fast_activation(f, x)) |
| 10 | + sumabs2!!(f, x) = sum(abs2, fast_activation!!(f, copy(x))) |
| 11 | + |
| 12 | + function ∇sumabs2(f, x) |
| 13 | + dx = Enzyme.make_zero(x) |
| 14 | + Enzyme.autodiff(Reverse, sumabs2, Active, Const(f), Duplicated(x, dx)) |
| 15 | + return dx |
| 16 | + end |
| 17 | + |
| 18 | + function ∇sumabs2!!(f, x) |
| 19 | + dx = Enzyme.make_zero(x) |
| 20 | + Enzyme.autodiff(Reverse, sumabs2, Active, Const(f), Duplicated(x, dx)) |
| 21 | + return dx |
| 22 | + end |
| 23 | + |
| 24 | + x_act = randn(Float32, 10, 10) |
| 25 | + x_act_ca = Reactant.ConcreteRArray(x_act) |
| 26 | + |
| 27 | + @testset "Activation: $act" for act in ( |
| 28 | + identity, relu, sigmoid, tanh, tanh_fast, sigmoid_fast, gelu, abs2 |
| 29 | + ) |
| 30 | + f_compile = Reactant.compile(sumabs2, (act, x_act)) |
| 31 | + f_compile!! = Reactant.compile(sumabs2!!, (act, x_act)) |
| 32 | + |
| 33 | + y_simple = sumabs2(act, x_act) |
| 34 | + y_simple!! = sumabs2!!(act, x_act) |
| 35 | + y_compile = f_compile(act, x_act_ca) |
| 36 | + y_compile!! = f_compile!!(act, x_act_ca) |
| 37 | + |
| 38 | + @test y_simple ≈ y_compile |
| 39 | + @test y_simple!! ≈ y_compile!! |
| 40 | + |
| 41 | + ∂x_enz = Enzyme.make_zero(x_act) |
| 42 | + Enzyme.autodiff(Reverse, sumabs2, Active, Const(act), Duplicated(x_act, ∂x_enz)) |
| 43 | + |
| 44 | + ∂x_enz!! = Enzyme.make_zero(x_act) |
| 45 | + Enzyme.autodiff(Reverse, sumabs2!!, Active, Const(act), Duplicated(x_act, ∂x_enz!!)) |
| 46 | + |
| 47 | + ∇sumabs2_compiled = Reactant.compile(∇sumabs2, (act, x_act_ca)) |
| 48 | + ∂x_compile = ∇sumabs2_compiled(act, x_act_ca) |
| 49 | + |
| 50 | + ∇sumabs2!!_compiled = Reactant.compile(∇sumabs2!!, (act, x_act_ca)) |
| 51 | + ∂x_compile!! = ∇sumabs2!!_compiled(act, x_act_ca) |
| 52 | + |
| 53 | + @test ∂x_enz ≈ ∂x_compile broken=(act === gelu) |
| 54 | + @test ∂x_enz!! ≈ ∂x_compile!! broken=(act === gelu) |
| 55 | + end |
| 56 | +end |
8 | 57 |
|
9 | 58 | @testset "Fused Conv" begin end
|
0 commit comments