Skip to content

Commit 534d3e4

Browse files
committed
test: reduce test configs
1 parent 1f36df4 commit 534d3e4

File tree

3 files changed

+48
-74
lines changed

3 files changed

+48
-74
lines changed

test/nn/luxlib.jl

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -179,10 +179,7 @@ end
179179
end
180180

181181
@testset "Fused Conv" begin
182-
@testset for groups in (1, 2, 4),
183-
has_bias in (true, false),
184-
act in (identity, relu, sigmoid, tanh, gelu)
185-
182+
@testset for groups in (1, 2), has_bias in (true, false), act in (identity, relu, tanh)
186183
weight = randn(Float32, 4, 4, 8 ÷ groups, 4)
187184
x = randn(Float32, 16, 16, 8, 2)
188185
bias = has_bias ? randn(Float32, 4) : nothing
@@ -191,9 +188,9 @@ end
191188
x_reactant = Reactant.to_rarray(x)
192189
bias_reactant = Reactant.to_rarray(bias)
193190

194-
@testset for stride in ((1, 1), (2, 2), (3, 3)),
195-
padding in ((0, 0), (1, 1), (2, 2), (0, 2), (2, 0), (0, 1), (1, 0)),
196-
dilation in ((1, 1), (2, 2), (1, 2), (2, 1))
191+
@testset for stride in ((1, 1), (3, 3)),
192+
padding in ((0, 0), (2, 2), (2, 0)),
193+
dilation in ((1, 1), (1, 2))
197194

198195
conv_dims = DenseConvDims(x, weight; stride, padding, dilation, groups)
199196

test/nn/nnlib.jl

Lines changed: 40 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -1,58 +1,58 @@
11
using NNlib, Reactant, Enzyme
22
using Statistics
33

4-
# @testset "Activation Functions" begin
5-
# sumabs2(f, x) = sum(abs2, f.(x))
4+
@testset "Activation Functions" begin
5+
sumabs2(f, x) = sum(abs2, f.(x))
66

7-
# function ∇sumabs2(f, x)
8-
# dx = Enzyme.make_zero(x)
9-
# Enzyme.autodiff(Reverse, sumabs2, Active, Const(f), Duplicated(x, dx))
10-
# return dx
11-
# end
7+
function ∇sumabs2(f, x)
8+
dx = Enzyme.make_zero(x)
9+
Enzyme.autodiff(Reverse, sumabs2, Active, Const(f), Duplicated(x, dx))
10+
return dx
11+
end
1212

13-
# x_act = randn(Float32, 10, 10)
14-
# x_act_ca = Reactant.to_rarray(x_act)
13+
x_act = randn(Float32, 10, 10)
14+
x_act_ca = Reactant.to_rarray(x_act)
1515

16-
# @testset "Activation: $act" for act in (
17-
# identity, relu, sigmoid, tanh, tanh_fast, sigmoid_fast, gelu, abs2, relu6
18-
# )
19-
# f_compile = Reactant.compile(sumabs2, (act, x_act_ca))
16+
@testset "Activation: $act" for act in (
17+
identity, relu, sigmoid, tanh, tanh_fast, sigmoid_fast, gelu, abs2, relu6
18+
)
19+
f_compile = Reactant.compile(sumabs2, (act, x_act_ca))
2020

21-
# y_simple = sumabs2(act, x_act)
22-
# y_compile = f_compile(act, x_act_ca)
21+
y_simple = sumabs2(act, x_act)
22+
y_compile = f_compile(act, x_act_ca)
2323

24-
# ∂x_enz = Enzyme.make_zero(x_act)
25-
# Enzyme.autodiff(Reverse, sumabs2, Active, Const(act), Duplicated(x_act, ∂x_enz))
24+
∂x_enz = Enzyme.make_zero(x_act)
25+
Enzyme.autodiff(Reverse, sumabs2, Active, Const(act), Duplicated(x_act, ∂x_enz))
2626

27-
# ∇sumabs2_compiled = Reactant.compile(∇sumabs2, (act, x_act_ca))
27+
∇sumabs2_compiled = Reactant.compile(∇sumabs2, (act, x_act_ca))
2828

29-
# ∂x_compile = ∇sumabs2_compiled(act, x_act_ca)
29+
∂x_compile = ∇sumabs2_compiled(act, x_act_ca)
3030

31-
# @test y_simple ≈ y_compile
32-
# @test ∂x_enz ≈ ∂x_compile
33-
# end
34-
# end
31+
@test y_simple ≈ y_compile
32+
@test ∂x_enz ≈ ∂x_compile
33+
end
34+
end
3535

36-
# @testset "Pooling" begin
37-
# @testset for f in (NNlib.meanpool, NNlib.maxpool)
38-
# x = randn(Float32, 32, 32, 3, 2)
39-
# x_reactant = Reactant.to_rarray(x)
36+
@testset "Pooling" begin
37+
@testset for f in (NNlib.meanpool, NNlib.maxpool)
38+
x = randn(Float32, 32, 32, 3, 2)
39+
x_reactant = Reactant.to_rarray(x)
4040

41-
# @testset for window in ((2, 2), (3, 3), (4, 4)),
42-
# stride in ((1, 1), (2, 2)),
43-
# padding in ((0, 0), (1, 1), (2, 2), (0, 2), (2, 0))
41+
@testset for window in ((2, 2), (3, 3), (4, 4)),
42+
stride in ((1, 1), (2, 2)),
43+
padding in ((0, 0), (1, 1), (2, 2), (0, 2), (2, 0))
4444

45-
# pool_dims = PoolDims(x, window; stride, padding)
45+
pool_dims = PoolDims(x, window; stride, padding)
4646

47-
# f_reactant = Reactant.compile(f, (x_reactant, pool_dims))
47+
f_reactant = Reactant.compile(f, (x_reactant, pool_dims))
4848

49-
# broken = any(==(2), padding) && f === NNlib.maxpool && window == (2, 2)
50-
# @test f_reactant(x_reactant, pool_dims) ≈ f(x, pool_dims) broken = broken
51-
# end
49+
broken = any(==(2), padding) && f === NNlib.maxpool && window == (2, 2)
50+
@test f_reactant(x_reactant, pool_dims) ≈ f(x, pool_dims) broken = broken
51+
end
5252

53-
# # TODO: test for gradients
54-
# end
55-
# end
53+
# TODO: test for gradients
54+
end
55+
end
5656

5757
function ∇conv_data_filter(x, weight, conv_dims)
5858
dx, dweight = Enzyme.make_zero(x), Enzyme.make_zero(weight)
@@ -75,10 +75,8 @@ end
7575
x_reactant = Reactant.to_rarray(x)
7676

7777
@testset for stride in ((1, 1), (2, 2), (3, 3)),
78-
padding in ((0, 0), (1, 1), (2, 2), (0, 2), (2, 0), (0, 1), (1, 0)),
79-
dilation in ((1, 1), (2, 2), (1, 2), (2, 1))
80-
81-
@show groups, stride, padding, dilation
78+
padding in ((0, 0), (2, 2), (0, 2), (2, 0)),
79+
dilation in ((1, 1), (1, 2))
8280

8381
conv_dims = DenseConvDims(x, weight; stride, padding, dilation, groups)
8482

@@ -124,8 +122,6 @@ end
124122
end
125123
end
126124

127-
@info "Convolution done"
128-
129125
@testset "Batched Matrix Multiplication" begin
130126
Reactant.with_config(;
131127
convolution_precision=PrecisionConfig.HIGH,
@@ -157,8 +153,6 @@ end
157153
end
158154
end
159155

160-
@info "Batched Matrix Multiplication done"
161-
162156
@testset "Constant Padding: NNlib.pad_constant" begin
163157
x = rand(Float32, 4, 4)
164158
x_ra = Reactant.to_rarray(x)
@@ -205,8 +199,6 @@ end
205199
@test @jit(NNlib.pad_constant(x_ra, (1, 1))) ≈ NNlib.pad_constant(x, (1, 1))
206200
end
207201

208-
@info "Constant Padding done"
209-
210202
@testset "make_causal_mask" begin
211203
x = rand(2, 10)
212204
x_ra = Reactant.to_rarray(x)
@@ -217,8 +209,6 @@ end
217209
@test @jit(causal_mask2(x_ra)) ≈ causal_mask2(x)
218210
end
219211

220-
@info "make_causal_mask done"
221-
222212
# Adapted from https://github.com/FluxML/NNlib.jl/blob/02138682a4fc5ca019759218be50e59907d4527c/test/testsuite/gather.jl#L5
223213
@testset "NNlib gather" begin
224214
@testset "gather scalar index" begin
@@ -397,8 +387,6 @@ end
397387
end
398388
end
399389

400-
@info "Gather done"
401-
402390
# Adapted from https://github.com/FluxML/NNlib.jl/blob/1468582c4db5f18149cc8fff6fb4633c5debe5c5/test/testsuite/scatter.jl#L108
403391
@testset "NNlib scatter" begin
404392
function test_scatter(dsts, srcs, idxs, res; dims)
@@ -648,8 +636,6 @@ end
648636
end
649637
end
650638

651-
@info "Scatter done"
652-
653639
@testset "∇conv(D = $ndim)" for ndim in 1:3
654640
x_spatial_dim = 4
655641
batch_size = 2
@@ -683,8 +669,6 @@ end
683669
end
684670
end
685671

686-
@info "Convolution done"
687-
688672
@testset "Upsampling" begin
689673
x = randn(Float32, 4, 4, 3, 2)
690674
x_ra = Reactant.to_rarray(x)
@@ -726,8 +710,6 @@ end
726710
end
727711
end
728712

729-
@info "Upsampling done"
730-
731713
@testset "Pixel shuffle" begin
732714
x = [10i + j + channel / 10 for i in 1:2, j in 1:3, channel in 1:4, batch in 1:1]
733715
x_ra = Reactant.to_rarray(x)
@@ -740,8 +722,6 @@ end
740722
@test @jit(NNlib.pixel_shuffle(y_ra, 2)) ≈ NNlib.pixel_shuffle(y, 2)
741723
end
742724

743-
@info "pixel_shuffle done"
744-
745725
@testset "softmax/logsoftmax reshaped input" begin
746726
x = rand(Float32, 3, 4, 5)
747727
x_ra = reshape(Reactant.to_rarray(x), 12, 5)
@@ -750,5 +730,3 @@ end
750730
@test @jit(NNlib.softmax(x_ra)) ≈ NNlib.softmax(x)
751731
@test @jit(NNlib.logsoftmax(x_ra)) ≈ NNlib.logsoftmax(x)
752732
end
753-
754-
@info "softmax/logsoftmax done"

test/runtests.jl

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -53,14 +53,13 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all"))
5353
end
5454

5555
if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "neural_networks"
56+
# @safetestset "NNlib Primitives" include("nn/nnlib.jl")
5657
# @safetestset "Flux.jl Integration" include("nn/flux.jl")
57-
@safetestset "NNlib Primitives" include("nn/nnlib.jl")
58-
@info "NNlib Primitives tests finished"
5958
if Sys.islinux()
60-
# @safetestset "Lux Integration" include("nn/lux.jl") # XXX: need to fix crash
61-
# @info "Lux Integration tests finished"
62-
@safetestset "LuxLib Primitives" include("nn/luxlib.jl") # XXX: TPU takes too long
59+
@safetestset "LuxLib Primitives" include("nn/luxlib.jl")
6360
@info "LuxLib Primitives tests finished"
61+
@safetestset "Lux Integration" include("nn/lux.jl")
62+
@info "Lux Integration tests finished"
6463
end
6564
end
6665
end

0 commit comments

Comments
 (0)