Skip to content

Commit f5cb955

Browse files
author
Dhairya Gandhi
committed
refactor tests
1 parent 480754a commit f5cb955

File tree

3 files changed

+24
-18
lines changed

3 files changed

+24
-18
lines changed

test/conv.jl

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -352,12 +352,11 @@ conv_answer_dict = Dict(
352352
end
353353
end
354354
end
355+
end
355356

357+
if get(ENV,"NNLIB_TEST_FUZZING","false") == "true"
358+
# @info("Skipping Convolutional fuzzing tests, set NNLIB_TEST_FUZZING=true to run them")
356359
@testset "fuzzing" begin
357-
if get(ENV,"NNLIB_TEST_FUZZING","false") != "true"
358-
@info("Skipping Convolutional fuzzing tests, set NNLIB_TEST_FUZZING=true to run them")
359-
return
360-
end
361360
@info("Starting Convolutional fuzzing tests; this can take a few minutes...")
362361
# Now that we're fairly certain things are working, let's fuzz things a little bit:
363362
for x_size in (
@@ -441,9 +440,10 @@ conv_answer_dict = Dict(
441440
end
442441
println()
443442
end
443+
else
444+
@info "Skipping Convolutional fuzzing tests, set NNLIB_TEST_FUZZING=true to run them"
444445
end
445446

446-
447447
@testset "Depthwise Convolution" begin
448448
# Start with some easy-to-debug cases that we have worked through and _know_ work
449449
for rank in (1,) #2,3)
@@ -552,12 +552,11 @@ end
552552
end
553553
end
554554
end
555+
end
555556

557+
558+
if get(ENV,"NNLIB_TEST_FUZZING","false") == "true"
556559
@testset "fuzzing" begin
557-
if get(ENV,"NNLIB_TEST_FUZZING","false") != "true"
558-
@info("Skipping Depthwise Convolutional fuzzing tests, set NNLIB_TEST_FUZZING=true to run them")
559-
return
560-
end
561560
@info("Starting Depthwise Convolutional fuzzing tests; this can take a few minutes...")
562561
# Now that we're fairly certain things are working, let's fuzz things a little bit:
563562
for x_size in (
@@ -641,8 +640,11 @@ end
641640
end
642641
println()
643642
end
643+
else
644+
@info "Skipping Depthwise Convolutional fuzzing tests, set NNLIB_TEST_FUZZING=true to run them"
644645
end
645646

647+
646648
@testset "conv_wrapper" begin
647649
x = rand(10, 10, 3, 10)
648650
w = rand(2, 2, 3, 16)

test/inference.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
using NNlib, Test
2-
using NNlib: conv_direct, conv_im2col
1+
import NNlib: conv_direct, conv_im2col
32

43
@testset "Conv Inference" begin
54
x = rand(10, 10, 3, 2)
@@ -9,6 +8,9 @@ using NNlib: conv_direct, conv_im2col
98
NNlib.is_nnpack_available() && push!(impl, NNlib.conv_nnpack)
109

1110
for T in impl
12-
@inferred T(x, w, DenseConvDims(x, w))
11+
@test T(x, w, DenseConvDims(x, w)) isa AbstractArray{K,4} where K
1312
end
13+
14+
h() = error("check tests")
15+
@test h()
1416
end

test/pooling.jl

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using NNlib, Test
1+
#using NNlib, Test
22

33
maxpool_answer_dict = Dict(
44
1 => Dict(
@@ -298,11 +298,13 @@ for rank in (1, 2, 3)
298298
end
299299
end
300300

301-
x = rand(10, 10, 3, 10)
302-
@test size(maxpool(x, (2, 2))) == (5, 5, 3, 10)
303-
@test size(maxpool(x, (2, 2); pad = (1, 1), stride = (2, 2))) == (6, 6, 3, 10)
304-
@test size(meanpool(x, (2, 2))) == (5, 5, 3, 10)
305-
@test size(meanpool(x, (2, 2); pad = (1, 1), stride = (2, 2))) == (6, 6, 3, 10)
301+
@testset "Pooling - Check Sizes" begin
302+
x = rand(10, 10, 3, 10)
303+
@test size(maxpool(x, (2, 2))) == (5, 5, 3, 10)
304+
@test size(maxpool(x, (2, 2); pad = (1, 1), stride = (2, 2))) == (6, 6, 3, 10)
305+
@test size(meanpool(x, (2, 2))) == (5, 5, 3, 10)
306+
@test size(meanpool(x, (2, 2); pad = (1, 1), stride = (2, 2))) == (6, 6, 3, 10)
307+
end
306308

307309
# Add another test for 2d maxpool that uses an odd-length size:
308310
@testset "Issue #133" begin

0 commit comments

Comments
 (0)