|
1 | 1 | using NNlib, Reactant, Enzyme
|
2 | 2 | using Statistics
|
3 | 3 |
|
4 |
@testset "Activation Functions" begin
    # Reference objective: sum of squared activations, used both natively
    # and through the Reactant compiler.
    sumabs2(f, x) = sum(abs2, f.(x))

    # Reverse-mode gradient of `sumabs2` w.r.t. `x` via Enzyme; the
    # activation `f` is held constant.
    function ∇sumabs2(f, x)
        dx = Enzyme.make_zero(x)
        Enzyme.autodiff(Reverse, sumabs2, Active, Const(f), Duplicated(x, dx))
        return dx
    end

    xs = randn(Float32, 10, 10)
    xs_ra = Reactant.to_rarray(xs)

    @testset "Activation: $act" for act in (
        identity, relu, sigmoid, tanh, tanh_fast, sigmoid_fast, gelu, abs2, relu6
    )
        # Forward pass: plain Julia vs. Reactant-compiled execution.
        compiled_forward = Reactant.compile(sumabs2, (act, xs_ra))
        y_ref = sumabs2(act, xs)
        y_ra = compiled_forward(act, xs_ra)

        # Gradient: Enzyme on the host array vs. the compiled gradient.
        ∂x_ref = Enzyme.make_zero(xs)
        Enzyme.autodiff(Reverse, sumabs2, Active, Const(act), Duplicated(xs, ∂x_ref))

        compiled_grad = Reactant.compile(∇sumabs2, (act, xs_ra))
        ∂x_ra = compiled_grad(act, xs_ra)

        @test y_ref ≈ y_ra
        @test ∂x_ref ≈ ∂x_ra
    end
end
35 | 35 |
|
36 |
@testset "Pooling" begin
    @testset for f in (NNlib.meanpool, NNlib.maxpool)
        inp = randn(Float32, 32, 32, 3, 2)
        inp_ra = Reactant.to_rarray(inp)

        # Sweep window / stride / padding combinations; the cartesian
        # loop in a single @testset-for reports each combination separately.
        @testset for window in ((2, 2), (3, 3), (4, 4)),
            stride in ((1, 1), (2, 2)),
            padding in ((0, 0), (1, 1), (2, 2), (0, 2), (2, 0))

            pdims = PoolDims(inp, window; stride, padding)
            pool_compiled = Reactant.compile(f, (inp_ra, pdims))

            # Known failure: maxpool with a 2x2 window and any padding of 2
            # currently disagrees with NNlib — marked broken, not skipped.
            expect_broken = any(==(2), padding) && f === NNlib.maxpool && window == (2, 2)
            @test pool_compiled(inp_ra, pdims) ≈ f(inp, pdims) broken = expect_broken
        end

        # TODO: test for gradients
    end
end
56 | 56 |
|
57 | 57 | function ∇conv_data_filter(x, weight, conv_dims)
|
58 | 58 | dx, dweight = Enzyme.make_zero(x), Enzyme.make_zero(weight)
|
|
75 | 75 | x_reactant = Reactant.to_rarray(x)
|
76 | 76 |
|
77 | 77 | @testset for stride in ((1, 1), (2, 2), (3, 3)),
|
78 |
| - padding in ((0, 0), (1, 1), (2, 2), (0, 2), (2, 0), (0, 1), (1, 0)), |
79 |
| - dilation in ((1, 1), (2, 2), (1, 2), (2, 1)) |
80 |
| - |
81 |
| - @show groups, stride, padding, dilation |
| 78 | + padding in ((0, 0), (2, 2), (0, 2), (2, 0)), |
| 79 | + dilation in ((1, 1), (1, 2)) |
82 | 80 |
|
83 | 81 | conv_dims = DenseConvDims(x, weight; stride, padding, dilation, groups)
|
84 | 82 |
|
|
124 | 122 | end
|
125 | 123 | end
|
126 | 124 |
|
127 |
| -@info "Convolution done" |
128 |
| - |
129 | 125 | @testset "Batched Matrix Multiplication" begin
|
130 | 126 | Reactant.with_config(;
|
131 | 127 | convolution_precision=PrecisionConfig.HIGH,
|
|
157 | 153 | end
|
158 | 154 | end
|
159 | 155 |
|
160 |
| -@info "Batched Matrix Multiplication done" |
161 |
| - |
162 | 156 | @testset "Constant Padding: NNlib.pad_constant" begin
|
163 | 157 | x = rand(Float32, 4, 4)
|
164 | 158 | x_ra = Reactant.to_rarray(x)
|
|
205 | 199 | @test @jit(NNlib.pad_constant(x_ra, (1, 1))) ≈ NNlib.pad_constant(x, (1, 1))
|
206 | 200 | end
|
207 | 201 |
|
208 |
| -@info "Constant Padding done" |
209 |
| - |
210 | 202 | @testset "make_causal_mask" begin
|
211 | 203 | x = rand(2, 10)
|
212 | 204 | x_ra = Reactant.to_rarray(x)
|
|
217 | 209 | @test @jit(causal_mask2(x_ra)) ≈ causal_mask2(x)
|
218 | 210 | end
|
219 | 211 |
|
220 |
| -@info "make_causal_mask done" |
221 |
| - |
222 | 212 | # Adapted from https://github.com/FluxML/NNlib.jl/blob/02138682a4fc5ca019759218be50e59907d4527c/test/testsuite/gather.jl#L5
|
223 | 213 | @testset "NNlib gather" begin
|
224 | 214 | @testset "gather scalar index" begin
|
|
397 | 387 | end
|
398 | 388 | end
|
399 | 389 |
|
400 |
| -@info "Gather done" |
401 |
| - |
402 | 390 | # Adapted from https://github.com/FluxML/NNlib.jl/blob/1468582c4db5f18149cc8fff6fb4633c5debe5c5/test/testsuite/scatter.jl#L108
|
403 | 391 | @testset "NNlib scatter" begin
|
404 | 392 | function test_scatter(dsts, srcs, idxs, res; dims)
|
|
648 | 636 | end
|
649 | 637 | end
|
650 | 638 |
|
651 |
| -@info "Scatter done" |
652 |
| - |
653 | 639 | @testset "∇conv(D = $ndim)" for ndim in 1:3
|
654 | 640 | x_spatial_dim = 4
|
655 | 641 | batch_size = 2
|
|
683 | 669 | end
|
684 | 670 | end
|
685 | 671 |
|
686 |
| -@info "Convolution done" |
687 |
| - |
688 | 672 | @testset "Upsampling" begin
|
689 | 673 | x = randn(Float32, 4, 4, 3, 2)
|
690 | 674 | x_ra = Reactant.to_rarray(x)
|
|
726 | 710 | end
|
727 | 711 | end
|
728 | 712 |
|
729 |
| -@info "Upsampling done" |
730 |
| - |
731 | 713 | @testset "Pixel shuffle" begin
|
732 | 714 | x = [10i + j + channel / 10 for i in 1:2, j in 1:3, channel in 1:4, batch in 1:1]
|
733 | 715 | x_ra = Reactant.to_rarray(x)
|
|
740 | 722 | @test @jit(NNlib.pixel_shuffle(y_ra, 2)) ≈ NNlib.pixel_shuffle(y, 2)
|
741 | 723 | end
|
742 | 724 |
|
743 |
| -@info "pixel_shuffle done" |
744 |
| - |
745 | 725 | @testset "softmax/logsoftmax reshaped input" begin
|
746 | 726 | x = rand(Float32, 3, 4, 5)
|
747 | 727 | x_ra = reshape(Reactant.to_rarray(x), 12, 5)
|
|
750 | 730 | @test @jit(NNlib.softmax(x_ra)) ≈ NNlib.softmax(x)
|
751 | 731 | @test @jit(NNlib.logsoftmax(x_ra)) ≈ NNlib.logsoftmax(x)
|
752 | 732 | end
|
753 |
| - |
754 |
| -@info "softmax/logsoftmax done" |
0 commit comments