Skip to content

Commit 398bfdd

Browse files
Merge pull request #154 from dhairyagandhi96/dg/infer_test
Refactor tests and make NNPACK optional
2 parents 480754a + 20e8e47 commit 398bfdd

File tree

6 files changed

+45
-30
lines changed

6 files changed

+45
-30
lines changed

deps/build.jl

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,18 @@ end
3333

3434
# If we have a download, and we are unsatisfied (or the version we're
3535
# trying to install is not itself installed) then load it up!
36-
if unsatisfied || !isinstalled(dl_info...; prefix=prefix)
37-
# Download and install binaries
36+
# Download and install binaries
37+
use_nnpack = get(ENV, "NNLIB_USE_NNPACK", "false") == "true"
38+
os_support = Sys.islinux() || Sys.isapple()
39+
if use_nnpack && os_support
40+
if unsatisfied || !isinstalled(dl_info...; prefix=prefix)
3841
install(dl_info...; prefix=prefix, force=true, verbose=verbose)
42+
end
43+
# Write out a deps.jl file that will contain mappings for our products
44+
write_deps_file(joinpath(@__DIR__, "deps.jl"), products, verbose=verbose)
45+
else
46+
open(joinpath(@__DIR__, "deps.jl"), "w") do io
47+
write(io, "check_deps() = false")
48+
end
3949
end
4050

41-
# Write out a deps.jl file that will contain mappings for our products
42-
write_deps_file(joinpath(@__DIR__, "deps.jl"), products, verbose=verbose)

src/NNlib.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,11 @@ using Requires
55
include("dim_helpers.jl")
66

77
# NNPACK support
8-
if Sys.islinux() || Sys.isapple()
9-
include("nnpack/NNPACK.jl")
8+
include(joinpath(@__DIR__, "..", "deps", "deps.jl"))
9+
if check_deps() == nothing
10+
include("nnpack/NNPACK.jl")
1011
else
11-
is_nnpack_available() = false
12+
is_nnpack_available() = false
1213
end
1314

1415
include("activation.jl")

src/nnpack/NNPACK.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ const depsjl_path = joinpath(dirname(@__FILE__), "..", "..", "deps", "deps.jl")
88
if !isfile(depsjl_path)
99
error("NNPACK not installed properly, run Pkg.build(\"NNlib\"), restart Julia and try again")
1010
end
11-
include(depsjl_path)
1211

1312
const shared_threadpool_dict = Dict{UInt64, Base.RefValue}()
1413

@@ -18,7 +17,7 @@ const shared_threadpool_dict = Dict{UInt64, Base.RefValue}()
1817
Checks if the current hardware is supported by NNPACK.
1918
"""
2019
function is_nnpack_available()
21-
check_deps()
20+
check_deps() isa Nothing || return false
2221
status = nnp_initialize()
2322
if status == nnp_status_unsupported_hardware
2423
return false

test/conv.jl

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -274,9 +274,13 @@ conv_answer_dict = Dict(
274274
# A "drop channels and batch dimension" helper
275275
ddims(x) = dropdims(x, dims=(rank+1, rank+2))
276276

277-
for conv in (NNlib.conv, NNlib.conv_im2col, NNlib.conv_direct, NNlib.conv_nnpack)
278-
if conv == NNlib.conv_nnpack && !NNlib.nnpack_supported_operation(DenseConvDims(x, w))
279-
continue
277+
convs = [NNlib.conv, NNlib.conv_im2col, NNlib.conv_direct,]
278+
NNlib.is_nnpack_available() && push!(convs, NNlib.conv_nnpack)
279+
for conv in convs
280+
if NNlib.is_nnpack_available()
281+
if conv == NNlib.conv_nnpack && !NNlib.nnpack_supported_operation(DenseConvDims(x, w))
282+
continue
283+
end
280284
end
281285
@testset "$(conv)" begin
282286
cdims = DenseConvDims(x, w)
@@ -352,12 +356,11 @@ conv_answer_dict = Dict(
352356
end
353357
end
354358
end
359+
end
355360

361+
if get(ENV,"NNLIB_TEST_FUZZING","false") == "true"
362+
# @info("Skipping Convolutional fuzzing tests, set NNLIB_TEST_FUZZING=true to run them")
356363
@testset "fuzzing" begin
357-
if get(ENV,"NNLIB_TEST_FUZZING","false") != "true"
358-
@info("Skipping Convolutional fuzzing tests, set NNLIB_TEST_FUZZING=true to run them")
359-
return
360-
end
361364
@info("Starting Convolutional fuzzing tests; this can take a few minutes...")
362365
# Now that we're fairly certain things are working, let's fuzz things a little bit:
363366
for x_size in (
@@ -441,9 +444,10 @@ conv_answer_dict = Dict(
441444
end
442445
println()
443446
end
447+
else
448+
@info "Skipping Convolutional fuzzing tests, set NNLIB_TEST_FUZZING=true to run them"
444449
end
445450

446-
447451
@testset "Depthwise Convolution" begin
448452
# Start with some easy-to-debug cases that we have worked through and _know_ work
449453
for rank in (1,) #2,3)
@@ -552,12 +556,11 @@ end
552556
end
553557
end
554558
end
559+
end
560+
555561

562+
if get(ENV,"NNLIB_TEST_FUZZING","false") == "true"
556563
@testset "fuzzing" begin
557-
if get(ENV,"NNLIB_TEST_FUZZING","false") != "true"
558-
@info("Skipping Depthwise Convolutional fuzzing tests, set NNLIB_TEST_FUZZING=true to run them")
559-
return
560-
end
561564
@info("Starting Depthwise Convolutional fuzzing tests; this can take a few minutes...")
562565
# Now that we're fairly certain things are working, let's fuzz things a little bit:
563566
for x_size in (
@@ -641,8 +644,11 @@ end
641644
end
642645
println()
643646
end
647+
else
648+
@info "Skipping Depthwise Convolutional fuzzing tests, set NNLIB_TEST_FUZZING=true to run them"
644649
end
645650

651+
646652
@testset "conv_wrapper" begin
647653
x = rand(10, 10, 3, 10)
648654
w = rand(2, 2, 3, 16)

test/inference.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
using NNlib, Test
2-
using NNlib: conv_direct, conv_im2col
1+
import NNlib: conv_direct, conv_im2col
32

43
@testset "Conv Inference" begin
54
x = rand(10, 10, 3, 2)
@@ -9,6 +8,6 @@ using NNlib: conv_direct, conv_im2col
98
NNlib.is_nnpack_available() && push!(impl, NNlib.conv_nnpack)
109

1110
for T in impl
12-
@inferred T(x, w, DenseConvDims(x, w))
11+
@test T(x, w, DenseConvDims(x, w)) isa AbstractArray{K,4} where K
1312
end
1413
end

test/pooling.jl

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using NNlib, Test
1+
#using NNlib, Test
22

33
maxpool_answer_dict = Dict(
44
1 => Dict(
@@ -298,11 +298,13 @@ for rank in (1, 2, 3)
298298
end
299299
end
300300

301-
x = rand(10, 10, 3, 10)
302-
@test size(maxpool(x, (2, 2))) == (5, 5, 3, 10)
303-
@test size(maxpool(x, (2, 2); pad = (1, 1), stride = (2, 2))) == (6, 6, 3, 10)
304-
@test size(meanpool(x, (2, 2))) == (5, 5, 3, 10)
305-
@test size(meanpool(x, (2, 2); pad = (1, 1), stride = (2, 2))) == (6, 6, 3, 10)
301+
@testset "Pooling - Check Sizes" begin
302+
x = rand(10, 10, 3, 10)
303+
@test size(maxpool(x, (2, 2))) == (5, 5, 3, 10)
304+
@test size(maxpool(x, (2, 2); pad = (1, 1), stride = (2, 2))) == (6, 6, 3, 10)
305+
@test size(meanpool(x, (2, 2))) == (5, 5, 3, 10)
306+
@test size(meanpool(x, (2, 2); pad = (1, 1), stride = (2, 2))) == (6, 6, 3, 10)
307+
end
306308

307309
# Add another test for 2d maxpool that uses an odd-length size:
308310
@testset "Issue #133" begin

0 commit comments

Comments
 (0)