Skip to content

Commit ec79173

Browse files
authored
Add NNPACK support (#67)
Add NNPACK support
2 parents 40cee4b + ee86fbb commit ec79173

19 files changed

+670
-32
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,6 @@
77
*.dll
88
*~
99
\#*
10+
deps/usr
11+
deps.jl
12+
*.log

Manifest.toml

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,22 @@
33
[[Base64]]
44
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
55

6+
[[BinaryProvider]]
7+
deps = ["Libdl", "Pkg", "SHA", "Test"]
8+
git-tree-sha1 = "055eb2690182ebc31087859c3dd8598371d3ef9e"
9+
uuid = "b99e7846-7c00-51b0-8f62-c81ae34c0232"
10+
version = "0.5.3"
11+
612
[[Crayons]]
713
deps = ["Test"]
814
git-tree-sha1 = "f621b8ef51fd2004c7cf157ea47f027fdeac5523"
915
uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f"
1016
version = "4.0.0"
1117

18+
[[Dates]]
19+
deps = ["Printf"]
20+
uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"
21+
1222
[[Distributed]]
1323
deps = ["Random", "Serialization", "Sockets"]
1424
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
@@ -17,6 +27,9 @@ uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
1727
deps = ["Markdown"]
1828
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
1929

30+
[[LibGit2]]
31+
uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"
32+
2033
[[Libdl]]
2134
uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
2235

@@ -31,10 +44,18 @@ uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
3144
deps = ["Base64"]
3245
uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"
3346

47+
[[Pkg]]
48+
deps = ["Dates", "LibGit2", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"]
49+
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
50+
3451
[[Printf]]
3552
deps = ["Unicode"]
3653
uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7"
3754

55+
[[REPL]]
56+
deps = ["InteractiveUtils", "Markdown", "Sockets"]
57+
uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
58+
3859
[[Random]]
3960
deps = ["Serialization"]
4061
uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -45,6 +66,9 @@ git-tree-sha1 = "f6fbf4ba64d295e146e49e021207993b6b48c7d1"
4566
uuid = "ae029012-a4dd-5104-9daa-d747884805df"
4667
version = "0.5.2"
4768

69+
[[SHA]]
70+
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"
71+
4872
[[Serialization]]
4973
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
5074

@@ -69,5 +93,9 @@ git-tree-sha1 = "b80671c06f8f8bae08c55d67b5ce292c5ae2660c"
6993
uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
7094
version = "0.5.0"
7195

96+
[[UUIDs]]
97+
deps = ["Random", "SHA"]
98+
uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
99+
72100
[[Unicode]]
73101
uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
33
version = "0.6.0"
44

55
[deps]
6+
BinaryProvider = "b99e7846-7c00-51b0-8f62-c81ae34c0232"
67
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
78
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
89
Requires = "ae029012-a4dd-5104-9daa-d747884805df"

REQUIRE

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
11
julia 1.0
22
Requires
3+
MacroTools
4+
BinaryProvider
35
TimerOutputs

deps/build.jl

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# Build script: downloads a prebuilt NNPACK binary matching the current platform and
# writes `deps.jl`, which records where the `libnnpack` library product was installed.
using BinaryProvider

# Parse some basic command-line arguments:
#   --verbose          enable verbose output from BinaryProvider
#   <first other arg>  installation prefix (defaults to `deps/usr` next to this file)
const verbose = "--verbose" in ARGS
const prefix = Prefix(get([a for a in ARGS if a != "--verbose"], 1, joinpath(@__DIR__, "usr")))
products = [
    LibraryProduct(prefix, ["libnnpack"], :libnnpack),
]

# Download binaries from hosted location
bin_prefix = "https://github.com/JuliaPackaging/Yggdrasil/releases/download/NNPACK-v2018.06.22-0"

# Listing of files generated by BinaryBuilder (platform => (tarball URL, SHA-256 hash)):
download_info = Dict(
    Linux(:aarch64, libc=:glibc) => ("$bin_prefix/NNPACK.v2018.6.22.aarch64-linux-gnu.tar.gz", "e0c6e21ba4c47acfd5a3d3e3510e8786474080f654338f4583b88860296c1437"),
    Linux(:i686, libc=:glibc) => ("$bin_prefix/NNPACK.v2018.6.22.i686-linux-gnu.tar.gz", "e9b6685001bc5a5d17acef15f3f6ffeb7beb6081926300f23ed4a442beac71ca"),
    Linux(:i686, libc=:musl) => ("$bin_prefix/NNPACK.v2018.6.22.i686-linux-musl.tar.gz", "36c1d3c30b3bc3e0b34f215945bb46319f88e28f011fc758f21ba888b1fd9e25"),
    MacOS(:x86_64) => ("$bin_prefix/NNPACK.v2018.6.22.x86_64-apple-darwin14.tar.gz", "b30046223a11470b15a2ceb0d0df6f7d8a43260fe52f4a2f8ebe5f0b2df822ca"),
    Linux(:x86_64, libc=:glibc) => ("$bin_prefix/NNPACK.v2018.6.22.x86_64-linux-gnu.tar.gz", "150d5b6ca81fa72bfdc8bbda2428f0d3483fd11a5813724646c6d6c6a7ef969f"),
    Linux(:x86_64, libc=:musl) => ("$bin_prefix/NNPACK.v2018.6.22.x86_64-linux-musl.tar.gz", "d961a104f814ec5b356519a82746a70a1df193ae37fc8130f38ffb61336def16"),
)

# Install unsatisfied or updated dependencies:
unsatisfied = any(!satisfied(p; verbose=verbose) for p in products)
dl_info = choose_download(download_info, platform_key_abi())
if dl_info === nothing && unsatisfied
    # If we don't have a compatible .tar.gz to download, complain.
    # Alternatively, you could attempt to install from a separate provider,
    # build from source or something even more ambitious here.
    error("Your platform (\"$(Sys.MACHINE)\", parsed as \"$(triplet(platform_key_abi()))\") is not supported by this package!")
end

# If we have a download, and we are unsatisfied (or the version we're
# trying to install is not itself installed) then load it up!
# The `dl_info !== nothing` guard matters when the products are already satisfied on a
# platform with no matching tarball: splatting `nothing` into `isinstalled` would throw.
if dl_info !== nothing && (unsatisfied || !isinstalled(dl_info...; prefix=prefix))
    # Download and install binaries
    install(dl_info...; prefix=prefix, force=true, verbose=verbose)
end

# Write out a deps.jl file that will contain mappings for our products
write_deps_file(joinpath(@__DIR__, "deps.jl"), products, verbose=verbose)

src/NNlib.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,17 @@ using Requires, TimerOutputs
33

44
const to = TimerOutput()
55

6+
67
# Include APIs
78
include("dim_helpers.jl")
9+
10+
# NNPACK support.
# The NNPACK wrapper is only included on Linux and macOS. On every other OS a stub
# `is_nnpack_available` is defined instead, so callers can query availability uniformly.
if !(Sys.islinux() || Sys.isapple())
    is_nnpack_available() = false
else
    include("nnpack/NNPACK.jl")
end
16+
817
include("activation.jl")
918
include("softmax.jl")
1019
include("gemm.jl")

src/conv.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,3 +151,13 @@ for backend in (Symbol(), :_direct, :_im2col)
151151
end
152152
end
153153
end
154+
155+
156+
# Use NNPACK if it is available and the operation is supported.
#
# This method only matches 4-dimensional dense `Array` inputs combined with a 2-d
# `DenseConvDims` whose 5th and 7th type parameters are both `(1, 1)` (presumably stride
# and dilation -- TODO confirm against the `DenseConvDims` definition). All other calls
# keep dispatching to the generic `conv` methods.
if is_nnpack_available()
    # No static parameter is declared for the slots fixed to `(1, 1)`: Julia warns about
    # type variables that are declared in `where` but never used in the signature.
    function conv(x::Array{xT, 4}, w::Array{wT, 4},
                  cdims::DenseConvDims{2, K, C_in, C_out, (1, 1), P, (1, 1), F};
                  kwargs...) where {xT, wT, K, C_in, C_out, P, F}
        return conv_nnpack(x, w, cdims; kwargs...)
    end
end

src/dim_helpers.jl

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,4 +119,18 @@ function predilate(x::AbstractArray{T,N}, dilation::NTuple{M}) where {T, N, M}
119119
# zeros between each element of `x` along each spatial dimension.
120120
x_dil[(1:dilation[idx]:size(x_dil,idx) for idx in 1:(N-2))..., :, :] .= x
121121
return x_dil
122-
end
122+
end
123+
124+
"""
125+
flipweight(w::AbstractArray)
126+
127+
Reorders the weight tensor for supporting both convolution and cross-correlation operations.
128+
"""
129+
130+
# For any array with ndims <= 3 it makes no sense to flip the weights so simply return the
131+
# original array
132+
@inline flipweight(w::AbstractArray) = w
133+
134+
@inline flipweight(w::AbstractArray{T, 4}) where {T} = w[end:-1:1, end:-1:1, :, :]
135+
136+
@inline flipweight(w::AbstractArray{T, 5}) where {T} = w[end:-1:1, end:-1:1, end:-1:1, :, :]

src/nnpack/NNPACK.jl

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
include("libnnpack_types.jl")
2+
include("error.jl")
3+
include("libnnpack.jl")
4+
include("performance.jl")
5+
include("interface.jl")
6+
7+
# Path of the generated `deps.jl`: two directories above this file, inside `deps/`.
const depsjl_path = joinpath(dirname(@__FILE__), "..", "..", "deps", "deps.jl")

# Refuse to load without the generated file; otherwise pull its definitions in.
isfile(depsjl_path) ||
    error("NNPACK not installed properly, run Pkg.build(\"NNlib\"), restart Julia and try again")
include(depsjl_path)

# Threadpool references keyed by a `UInt64`; populated by `allocate_threadpool`.
const shared_threadpool_dict = Dict{UInt64, Base.RefValue}()
14+
15+
"""
16+
is_nnpack_available()
17+
18+
Checks if the current hardware is supported by NNPACK.
19+
"""
20+
function is_nnpack_available()
21+
check_deps()
22+
status = nnp_initialize()
23+
if status == nnp_status_unsupported_hardware
24+
return false
25+
else
26+
return true
27+
end
28+
end
29+
30+
"""
31+
allocate_threadpool()
32+
33+
Allocates several threadpool based on the upper limit on the number of threads for the machine.
34+
Allows NNPACK to intelligently choose which threadpool to use for getting the best
35+
performance.
36+
"""
37+
function allocate_threadpool()
38+
global NNPACK_CPU_THREADS = NNPACK_CPU_THREADS > 8 ? UInt64(8) : floor(log2(NNPACK_CPU_THREADS))
39+
for i in 1:Int(NNPACK_CPU_THREADS)
40+
threads = UInt64(2^i)
41+
push!(shared_threadpool_dict, threads => Ref(pthreadpool_create(threads)))
42+
end
43+
end
44+
45+
# Initialization block: verifies the binary dependencies, initializes NNPACK, reads the
# thread limit and allocates the shared threadpools.
@init begin
    check_deps()
    status = nnp_initialize()
    if status == nnp_status_unsupported_hardware
        @warn "Hardware is unsupported by NNPACK so falling back to default NNlib"
    end
    # The thread limit comes from the NNPACK_CPU_THREADS environment variable. When the
    # variable is unset or is not a valid UInt64, fall back to 4. Sys.CPU_THREADS would
    # be a better default when tuning the benchmark suite on a particular machine, but
    # the default is deliberately fixed here.
    requested = tryparse(UInt64, get(ENV, "NNPACK_CPU_THREADS", ""))
    global NNPACK_CPU_THREADS = requested === nothing ? UInt64(4) : requested
    allocate_threadpool()
end

src/nnpack/error.jl

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
"""
    NNPACKError(code::nnp_status, msg)

Exception carrying an NNPACK status `code` together with a human-readable `msg`.

`msg` is stored as a concrete `String`; any `AbstractString` passed to the constructor is
converted.
"""
struct NNPACKError <: Exception
    code::nnp_status
    msg::String
end

# Printed form: `NNPACKError(code <code>, <msg>)`.
Base.show(io::IO, err::NNPACKError) = print(io, "NNPACKError(code $(err.code), $(err.msg))")
7+
8+
# Build an `NNPACKError` from a bare status code, attaching the matching message.
# The table is scanned in order with `==` and the first match wins, so the order of the
# entries is significant if two status names ever compare equal. Any status without an
# entry (including `nnp_status_success`) gets the "NNPACK STATUS SUCCESS" message.
function NNPACKError(status::nnp_status)
    messages = (
        nnp_status_invalid_batch_size => "NNPACK STATUS INVALID BATCH SIZE",
        nnp_status_invalid_channels => "NNPACK STATUS INVALID CHANNELS",
        nnp_status_invalid_input_channels => "NNPACK STATUS INVALID INPUT CHANNELS",
        nnp_status_invalid_output_channels => "NNPACK STATUS INVALID OUTPUT CHANNELS",
        nnp_status_invalid_input_size => "NNPACK STATUS INVALID INPUT SIZE",
        nnp_status_invalid_input_stride => "NNPACK STATUS INVALID INPUT STRIDE",
        nnp_status_invalid_input_padding => "NNPACK STATUS INVALID INPUT PADDING",
        nnp_status_invalid_kernel_size => "NNPACK STATUS INVALID KERNEL SIZE",
        nnp_status_invalid_pooling_size => "NNPACK STATUS INVALID POOLING SIZE",
        nnp_status_invalid_pooling_stride => "NNPACK STATUS INVALID POOLING STRIDE",
        nnp_status_invalid_algorithm => "NNPACK STATUS INVALID ALGORITHM",
        nnp_status_invalid_transform_strategy => "NNPACK STATUS INVALID TRANSFORM STRATEGY",
        nnp_status_invalid_output_subsampling => "NNPACK STATUS INVALID OUTPUT SUBSAMPLING",
        nnp_status_invalid_activation => "NNPACK STATUS INVALID ACTIVATION",
        nnp_status_invalid_activation_parameters => "NNPACK STATUS INVALID ACTIVATION PARAMETERS",
        nnp_status_unsupported_input_size => "NNPACK STATUS UNSUPPORTED INPUT SIZE",
        nnp_status_unsupported_input_stride => "NNPACK STATUS UNSUPPORTED INPUT STRIDE",
        nnp_status_unsupported_input_padding => "NNPACK STATUS UNSUPPORTED INPUT PADDING",
        nnp_status_unsupported_kernel_size => "NNPACK STATUS UNSUPPORTED KERNEL SIZE",
        nnp_status_unsupported_pooling_size => "NNPACK STATUS UNSUPPORTED POOLING SIZE",
        nnp_status_unsupported_pooling_stride => "NNPACK STATUS UNSUPPORTED POOLING STRIDE",
        nnp_status_unsupported_algorithm => "NNPACK STATUS UNSUPPORTED ALGORITHM",
        nnp_status_unsupported_transform_strategy => "NNPACK STATUS UNSUPPORTED TRANSFORM STRATEGY",
        nnp_status_unsupported_activation => "NNPACK STATUS UNSUPPORTED ACTIVATION",
        nnp_status_unsupported_activation_parameters => "NNPACK STATUS UNSUPPORTED ACTIVATION PARAMETERS",
        nnp_status_uninitialized => "NNPACK STATUS UNINITIALIZED",
        nnp_status_unsupported_hardware => "NNPACK STATUS UNSUPPORTED HARDWARE",
        nnp_status_out_of_memory => "NNPACK STATUS OUT OF MEMORY",
        nnp_status_insufficient_buffer => "NNPACK STATUS INSUFFICIENT BUFFER",
        nnp_status_misaligned_buffer => "NNPACK STATUS MISALIGNED BUFFER",
    )
    for (known, text) in messages
        if status == known
            return NNPACKError(status, text)
        end
    end
    return NNPACKError(status, "NNPACK STATUS SUCCESS")
end
73+
74+
# Wrap an expression that returns an `nnp_status`: evaluate it in the caller's scope,
# throw an `NNPACKError` for anything other than `nnp_status_success`, and otherwise
# yield the status as the value of the expression.
macro nnpack_check(nnp_func)
    return quote
        local status::nnp_status = $(esc(nnp_func))
        status == nnp_status_success || throw(NNPACKError(status))
        status
    end
end

0 commit comments

Comments
 (0)