|
1 | 1 | using LuxLib, Reactant, Enzyme, NNlib
|
2 | 2 |
|
3 |
| -@testset "Fused Dense" begin end |
| 3 | +@testset "Fused Dense" begin |
| 4 | + sumabs2fuseddense(act, weight, x, bias) = |
| 5 | + sum(abs2, fused_dense_bias_activation(act, weight, x, bias)) |
| 6 | + |
| 7 | + function ∇fuseddense(act, weight, x, bias) |
| 8 | + dw = Enzyme.make_zero(weight) |
| 9 | + dx = Enzyme.make_zero(x) |
| 10 | + db = bias === nothing ? nothing : Enzyme.make_zero(bias) |
| 11 | + b_dup = bias === nothing ? Const(bias) : Duplicated(bias, db) |
| 12 | + Enzyme.autodiff( |
| 13 | + Reverse, |
| 14 | + sumabs2fuseddense, |
| 15 | + Active, |
| 16 | + Const(act), |
| 17 | + Duplicated(weight, dw), |
| 18 | + Duplicated(x, dx), |
| 19 | + b_dup, |
| 20 | + ) |
| 21 | + return dw, dx, db |
| 22 | + end |
| 23 | + |
| 24 | + @testset for act in (identity, relu, sigmoid, tanh, gelu), has_bias in (true, false) |
| 25 | + weight = randn(Float32, 9, 10) |
| 26 | + x = randn(Float32, 10, 12) |
| 27 | + bias = has_bias ? randn(Float32, 9) : nothing |
| 28 | + |
| 29 | + weight_ra = Reactant.ConcreteRArray(weight) |
| 30 | + x_ra = Reactant.ConcreteRArray(x) |
| 31 | + bias_ra = Reactant.to_rarray(bias) |
| 32 | + |
| 33 | + f_compile = Reactant.compile( |
| 34 | + fused_dense_bias_activation, (act, weight_ra, x_ra, bias_ra) |
| 35 | + ) |
| 36 | + |
| 37 | + y_res = fused_dense_bias_activation(act, weight, x, bias) |
| 38 | + y_compile = f_compile(act, weight_ra, x_ra, bias_ra) |
| 39 | + |
| 40 | + @test y_res ≈ y_compile broken = (act === gelu) |
| 41 | + |
| 42 | + @testset "Enzyme: fused_dense_bias_activation" begin |
| 43 | + dw, dx, db = ∇fuseddense(act, weight, x, bias) |
| 44 | + ∇fuseddense_compiled = Reactant.compile( |
| 45 | + ∇fuseddense, (act, weight_ra, x_ra, bias_ra) |
| 46 | + ) |
| 47 | + dw_compile, dx_compile, db_compile = ∇fuseddense_compiled( |
| 48 | + act, weight_ra, x_ra, bias_ra |
| 49 | + ) |
| 50 | + |
| 51 | + @test dw ≈ dw_compile broken = (act === gelu) |
| 52 | + @test dx ≈ dx_compile broken = (act === gelu) |
| 53 | + has_bias && @test db ≈ db_compile broken = (act === gelu) |
| 54 | + end |
| 55 | + end |
| 56 | +end |
4 | 57 |
|
5 | 58 | @testset "Bias Activation" begin
|
6 | 59 | biasact(act, x, b) = bias_activation(act, x, b)
|
@@ -54,7 +107,6 @@ using LuxLib, Reactant, Enzyme, NNlib
|
54 | 107 | @test y_simple ≈ y_compile broken = (act === gelu)
|
55 | 108 | @test y_simple!! ≈ y_compile!! broken = (act === gelu)
|
56 | 109 |
|
57 |
| - # FIXME: Seems broken currently |
58 | 110 | @testset "Enzyme: bias_activation" begin
|
59 | 111 | ∂x_enz, ∂b_enz = ∇biasact(act, x, b)
|
60 | 112 | ∇biasact_compiled = Reactant.compile(∇biasact, (act, x_ra, b_ra))
|
|
106 | 158 | y_compile = f_compile(act, x_act_ca)
|
107 | 159 | y_compile!! = f_compile!!(act, x_act_ca)
|
108 | 160 |
|
109 |
| - @test y_simple ≈ y_compile |
110 |
| - @test y_simple!! ≈ y_compile!! |
| 161 | + @test y_simple ≈ y_compile broken = (act === gelu) |
| 162 | + @test y_simple!! ≈ y_compile!! broken = (act === gelu) |
111 | 163 |
|
112 | 164 | ∂x_enz = Enzyme.make_zero(x_act)
|
113 | 165 | Enzyme.autodiff(Reverse, sumabs2, Active, Const(act), Duplicated(x_act, ∂x_enz))
|
|
0 commit comments