|
1 | 1 | using NNlib, Reactant, Enzyme
|
2 | 2 | using Statistics
|
3 | 3 |
|
4 |
| -@testset "Activation Functions" begin |
5 |
| - sumabs2(f, x) = sum(abs2, f.(x)) |
| 4 | +# @testset "Activation Functions" begin |
| 5 | +# sumabs2(f, x) = sum(abs2, f.(x)) |
6 | 6 |
|
7 |
| - function ∇sumabs2(f, x) |
8 |
| - dx = Enzyme.make_zero(x) |
9 |
| - Enzyme.autodiff(Reverse, sumabs2, Active, Const(f), Duplicated(x, dx)) |
10 |
| - return dx |
11 |
| - end |
| 7 | +# function ∇sumabs2(f, x) |
| 8 | +# dx = Enzyme.make_zero(x) |
| 9 | +# Enzyme.autodiff(Reverse, sumabs2, Active, Const(f), Duplicated(x, dx)) |
| 10 | +# return dx |
| 11 | +# end |
12 | 12 |
|
13 |
| - x_act = randn(Float32, 10, 10) |
14 |
| - x_act_ca = Reactant.to_rarray(x_act) |
| 13 | +# x_act = randn(Float32, 10, 10) |
| 14 | +# x_act_ca = Reactant.to_rarray(x_act) |
15 | 15 |
|
16 |
| - @testset "Activation: $act" for act in ( |
17 |
| - identity, relu, sigmoid, tanh, tanh_fast, sigmoid_fast, gelu, abs2, relu6 |
18 |
| - ) |
19 |
| - f_compile = Reactant.compile(sumabs2, (act, x_act_ca)) |
| 16 | +# @testset "Activation: $act" for act in ( |
| 17 | +# identity, relu, sigmoid, tanh, tanh_fast, sigmoid_fast, gelu, abs2, relu6 |
| 18 | +# ) |
| 19 | +# f_compile = Reactant.compile(sumabs2, (act, x_act_ca)) |
20 | 20 |
|
21 |
| - y_simple = sumabs2(act, x_act) |
22 |
| - y_compile = f_compile(act, x_act_ca) |
| 21 | +# y_simple = sumabs2(act, x_act) |
| 22 | +# y_compile = f_compile(act, x_act_ca) |
23 | 23 |
|
24 |
| - ∂x_enz = Enzyme.make_zero(x_act) |
25 |
| - Enzyme.autodiff(Reverse, sumabs2, Active, Const(act), Duplicated(x_act, ∂x_enz)) |
| 24 | +# ∂x_enz = Enzyme.make_zero(x_act) |
| 25 | +# Enzyme.autodiff(Reverse, sumabs2, Active, Const(act), Duplicated(x_act, ∂x_enz)) |
26 | 26 |
|
27 |
| - ∇sumabs2_compiled = Reactant.compile(∇sumabs2, (act, x_act_ca)) |
| 27 | +# ∇sumabs2_compiled = Reactant.compile(∇sumabs2, (act, x_act_ca)) |
28 | 28 |
|
29 |
| - ∂x_compile = ∇sumabs2_compiled(act, x_act_ca) |
| 29 | +# ∂x_compile = ∇sumabs2_compiled(act, x_act_ca) |
30 | 30 |
|
31 |
| - @test y_simple ≈ y_compile |
32 |
| - @test ∂x_enz ≈ ∂x_compile |
33 |
| - end |
34 |
| -end |
| 31 | +# @test y_simple ≈ y_compile |
| 32 | +# @test ∂x_enz ≈ ∂x_compile |
| 33 | +# end |
| 34 | +# end |
35 | 35 |
|
36 |
| -@testset "Pooling" begin |
37 |
| - @testset for f in (NNlib.meanpool, NNlib.maxpool) |
38 |
| - x = randn(Float32, 32, 32, 3, 2) |
39 |
| - x_reactant = Reactant.to_rarray(x) |
| 36 | +# @testset "Pooling" begin |
| 37 | +# @testset for f in (NNlib.meanpool, NNlib.maxpool) |
| 38 | +# x = randn(Float32, 32, 32, 3, 2) |
| 39 | +# x_reactant = Reactant.to_rarray(x) |
40 | 40 |
|
41 |
| - @testset for window in ((2, 2), (3, 3), (4, 4)), |
42 |
| - stride in ((1, 1), (2, 2)), |
43 |
| - padding in ((0, 0), (1, 1), (2, 2), (0, 2), (2, 0)) |
| 41 | +# @testset for window in ((2, 2), (3, 3), (4, 4)), |
| 42 | +# stride in ((1, 1), (2, 2)), |
| 43 | +# padding in ((0, 0), (1, 1), (2, 2), (0, 2), (2, 0)) |
44 | 44 |
|
45 |
| - pool_dims = PoolDims(x, window; stride, padding) |
| 45 | +# pool_dims = PoolDims(x, window; stride, padding) |
46 | 46 |
|
47 |
| - f_reactant = Reactant.compile(f, (x_reactant, pool_dims)) |
| 47 | +# f_reactant = Reactant.compile(f, (x_reactant, pool_dims)) |
48 | 48 |
|
49 |
| - broken = any(==(2), padding) && f === NNlib.maxpool && window == (2, 2) |
50 |
| - @test f_reactant(x_reactant, pool_dims) ≈ f(x, pool_dims) broken = broken |
51 |
| - end |
| 49 | +# broken = any(==(2), padding) && f === NNlib.maxpool && window == (2, 2) |
| 50 | +# @test f_reactant(x_reactant, pool_dims) ≈ f(x, pool_dims) broken = broken |
| 51 | +# end |
52 | 52 |
|
53 |
| - # TODO: test for gradients |
54 |
| - end |
55 |
| -end |
| 53 | +# # TODO: test for gradients |
| 54 | +# end |
| 55 | +# end |
56 | 56 |
|
57 | 57 | function ∇conv_data_filter(x, weight, conv_dims)
|
58 | 58 | dx, dweight = Enzyme.make_zero(x), Enzyme.make_zero(weight)
|
|
88 | 88 | dy = ones(Float32, output_size)
|
89 | 89 | dy_reactant = Reactant.to_rarray(dy)
|
90 | 90 |
|
91 |
| - Reactant.with_config(; convolution_precision=PrecisionConfig.HIGHEST) do |
| 91 | + Reactant.with_config(; convolution_precision=PrecisionConfig.HIGH) do |
92 | 92 | @test @jit(NNlib.conv(x_reactant, weight_reactant, conv_dims)) ≈
|
93 | 93 | NNlib.conv(x, weight, conv_dims)
|
94 | 94 |
|
@@ -122,10 +122,12 @@ end
|
122 | 122 | end
|
123 | 123 | end
|
124 | 124 |
|
| 125 | +@info "Convolution done" |
| 126 | + |
125 | 127 | @testset "Batched Matrix Multiplication" begin
|
126 | 128 | Reactant.with_config(;
|
127 |
| - convolution_precision=PrecisionConfig.HIGHEST, |
128 |
| - dot_general_precision=PrecisionConfig.HIGHEST, |
| 129 | + convolution_precision=PrecisionConfig.HIGH, |
| 130 | + dot_general_precision=PrecisionConfig.HIGH, |
129 | 131 | ) do
|
130 | 132 | x = rand(Float32, 4, 3, 5)
|
131 | 133 | y = rand(Float32, 3, 2, 5)
|
|
153 | 155 | end
|
154 | 156 | end
|
155 | 157 |
|
| 158 | +@info "Batched Matrix Multiplication done" |
| 159 | + |
156 | 160 | @testset "Constant Padding: NNlib.pad_constant" begin
|
157 | 161 | x = rand(Float32, 4, 4)
|
158 | 162 | x_ra = Reactant.to_rarray(x)
|
|
199 | 203 | @test @jit(NNlib.pad_constant(x_ra, (1, 1))) ≈ NNlib.pad_constant(x, (1, 1))
|
200 | 204 | end
|
201 | 205 |
|
| 206 | +@info "Constant Padding done" |
| 207 | + |
202 | 208 | @testset "make_causal_mask" begin
|
203 | 209 | x = rand(2, 10)
|
204 | 210 | x_ra = Reactant.to_rarray(x)
|
|
209 | 215 | @test @jit(causal_mask2(x_ra)) ≈ causal_mask2(x)
|
210 | 216 | end
|
211 | 217 |
|
| 218 | +@info "make_causal_mask done" |
| 219 | + |
212 | 220 | # Adapted from https://github.com/FluxML/NNlib.jl/blob/02138682a4fc5ca019759218be50e59907d4527c/test/testsuite/gather.jl#L5
|
213 | 221 | @testset "NNlib gather" begin
|
214 | 222 | @testset "gather scalar index" begin
|
|
387 | 395 | end
|
388 | 396 | end
|
389 | 397 |
|
| 398 | +@info "Gather done" |
| 399 | + |
390 | 400 | # Adapted from https://github.com/FluxML/NNlib.jl/blob/1468582c4db5f18149cc8fff6fb4633c5debe5c5/test/testsuite/scatter.jl#L108
|
391 | 401 | @testset "NNlib scatter" begin
|
392 | 402 | function test_scatter(dsts, srcs, idxs, res; dims)
|
|
636 | 646 | end
|
637 | 647 | end
|
638 | 648 |
|
| 649 | +@info "Scatter done" |
| 650 | + |
639 | 651 | @testset "∇conv(D = $ndim)" for ndim in 1:3
|
640 | 652 | x_spatial_dim = 4
|
641 | 653 | batch_size = 2
|
|
654 | 666 | ) in Iterators.product(
|
655 | 667 | (0, 2), (1, 2), (1,), (1,)
|
656 | 668 | )
|
657 |
| - Reactant.with_config(; convolution_precision=PrecisionConfig.HIGHEST) do |
| 669 | + Reactant.with_config(; convolution_precision=PrecisionConfig.HIGH) do |
658 | 670 | conv_dims = NNlib.DenseConvDims(x, w; padding, stride, dilation, groups)
|
659 | 671 |
|
660 | 672 | output_size = (NNlib.output_size(conv_dims)..., n_out_features, batch_size)
|
|
669 | 681 | end
|
670 | 682 | end
|
671 | 683 |
|
| 684 | +@info "Convolution done" |
| 685 | + |
672 | 686 | @testset "Upsampling" begin
|
673 | 687 | x = randn(Float32, 4, 4, 3, 2)
|
674 | 688 | x_ra = Reactant.to_rarray(x)
|
@@ -710,18 +724,22 @@ end
|
710 | 724 | end
|
711 | 725 | end
|
712 | 726 |
|
| 727 | +@info "Upsampling done" |
| 728 | + |
713 | 729 | @testset "Pixel shuffle" begin
|
714 |
| - x = Int32[10i + j + channel / 10 for i in 1:2, j in 1:3, channel in 1:4, batch in 1:1] |
| 730 | + x = [10i + j + channel / 10 for i in 1:2, j in 1:3, channel in 1:4, batch in 1:1] |
715 | 731 | x_ra = Reactant.to_rarray(x)
|
716 | 732 |
|
717 | 733 | @test @jit(NNlib.pixel_shuffle(x_ra, 2)) ≈ NNlib.pixel_shuffle(x, 2)
|
718 | 734 |
|
719 |
| - y = Int32[i + channel / 10 for i in 1:3, channel in 1:6, batch in 1:1] |
| 735 | + y = [i + channel / 10 for i in 1:3, channel in 1:6, batch in 1:1] |
720 | 736 | y_ra = Reactant.to_rarray(y)
|
721 | 737 |
|
722 | 738 | @test @jit(NNlib.pixel_shuffle(y_ra, 2)) ≈ NNlib.pixel_shuffle(y, 2)
|
723 | 739 | end
|
724 | 740 |
|
| 741 | +@info "pixel_shuffle done" |
| 742 | + |
725 | 743 | @testset "softmax/logsoftmax reshaped input" begin
|
726 | 744 | x = rand(Float32, 3, 4, 5)
|
727 | 745 | x_ra = reshape(Reactant.to_rarray(x), 12, 5)
|
|
730 | 748 | @test @jit(NNlib.softmax(x_ra)) ≈ NNlib.softmax(x)
|
731 | 749 | @test @jit(NNlib.logsoftmax(x_ra)) ≈ NNlib.logsoftmax(x)
|
732 | 750 | end
|
| 751 | + |
| 752 | +@info "softmax/logsoftmax done" |
0 commit comments