
Commit e2947d8

Author: Avik Pal (committed)

Clean the code

1 parent 6025166 commit e2947d8

6 files changed: +115 -79 lines changed

src/NNlib.jl

Lines changed: 8 additions & 1 deletion
@@ -13,8 +13,15 @@ include("linalg.jl")
 include("conv.jl")
 include("cubroadcast.jl")
 
-if Sys.islinux()
+try
+    global ENABLE_NNPACK = parse(UInt64, ENV["ENABLE_NNPACK"])
+catch
+    global ENABLE_NNPACK = 1
+end
+
+if Sys.islinux() && ENABLE_NNPACK == 1
     include("nnpack/NNPACK.jl")
+    include("backends.jl")
 end
 
 end # module
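
With this change, NNPACK can be switched off before NNlib is loaded by setting the ENABLE_NNPACK environment variable; the value is read once at include time, so it has to be set first. A minimal usage sketch (the "0" value is only an example; per the diff, any parseable value other than 1 disables the backend, and an unset or unparsable value defaults to 1):

ENV["ENABLE_NNPACK"] = "0"   # any value other than 1 skips the NNPACK backend
using NNlib                  # nnpack/NNPACK.jl and backends.jl are not included

Leaving the variable unset keeps the default ENABLE_NNPACK = 1, so NNPACK stays enabled on Linux.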

src/backends.jl

Lines changed: 68 additions & 0 deletions
@@ -0,0 +1,68 @@
+function nnpack_supported_operation(x::AbstractArray{<:Real, 4}, k, pad, stride, dilation)
+    fallback = false
+    # NNPACK does not support dilated convolutions
+    dilation == 1 || dilation == (1, 1) || (fallback = true)
+    # Expand the pad and stride to have the same dimensions as k
+    pad_, stride_ = expand(Val{length(k)}, pad), expand(Val{length(k)}, stride)
+    (size(x, 1) - k[1] + 2 * pad_[1]) % stride_[1] == 0 && (size(x, 2) - k[2] + 2 * pad_[2]) % stride_[2] == 0 || (fallback = true)
+    # Return the expanded pad_ and stride_ as well
+    return pad_, stride_, fallback
+end
+
+function nnpack_speed_check(x::AbstractArray{<:Real, 4}, k, pad, stride, dilation)
+    # Add heuristics here to determine whether or not to use NNPACK
+    # For now just return true
+    return true
+end
+
+# NNPACK supports only Float32 operations, so Float64 keeps its default behaviour
+
+# Pooling
+function maxpool!(y::A, x::A, k; pad = map(_ -> 0, k), stride = k) where A<:Array{Float32, 4}
+    pad_, stride_, use_default = nnpack_supported_operation(x, k, pad, stride, 1)
+    use_nnpack = !use_default
+    # Only use NNPACK if we get a speed improvement
+    use_nnpack && (use_nnpack = nnpack_speed_check(x, k, pad, stride, 1))
+    if use_nnpack
+        nnpack_max_pooling!(y, x, k, pad = pad_, stride = stride_)
+    else
+        maxpool_cpu!(y, x, k, pad = pad_, stride = stride_)
+    end
+end
+
+# Convolutions
+function conv!(y::A, x::A, w::A; pad = 0, stride = 1, dilation = 1, algo = UInt32(0), flipkernel = 0) where A<:Array{Float32, 4}
+    k = (size(w, 1), size(w, 2))
+    pad_, stride_, use_default = nnpack_supported_operation(x, k, pad, stride, 1)
+    use_nnpack = !use_default
+    use_nnpack && (use_nnpack = nnpack_speed_check(x, k, pad, stride, 1))
+    if use_nnpack
+        nnpack_convolution_forward!(y, x, w, zeros(Float32, size(y, 3)), algo = algo, pad = pad, stride = stride, flipkernel = flipkernel)
+    else
+        conv2d!(y, x, w, padding = pad_, stride = stride_, dilation = dilation, mode = flipkernel)
+    end
+end
+
+function ∇conv_data!(dx::A, dy::A, x::A, w::A; pad = 0, stride = 1, dilation = 1, algo = UInt32(0), flipkernel = 0) where A<:Array{Float32, 4}
+    k = (size(w, 1), size(w, 2))
+    pad_, stride_, use_default = nnpack_supported_operation(x, k, pad, stride, 1)
+    use_nnpack = !use_default
+    use_nnpack && (use_nnpack = nnpack_speed_check(x, k, pad, stride, 1))
+    if use_nnpack
+        nnpack_convolution_backward_data!(dx, x, dy, w, pad = pad_, stride = stride_, algo = algo, flipkernel = flipkernel)
+    else
+        conv2d_grad_x!(dx, x, w, dy, padding = pad_, stride = stride_, dilation = dilation, mode = flipkernel)
+    end
+end
+
+function ∇conv_filter!(dw::A, dy::A, x::A, w::A; pad = 0, stride = 1, dilation = 1, algo = UInt32(0), flipkernel = 0) where A<:Array{Float32, 4}
+    k = (size(w, 1), size(w, 2))
+    pad_, stride_, use_default = nnpack_supported_operation(x, k, pad, stride, 1)
+    use_nnpack = !use_default
+    use_nnpack && (use_nnpack = nnpack_speed_check(x, k, pad, stride, 1))
+    if use_nnpack
+        nnpack_convolution_backward_filter!(dw, x, dy, w, pad = pad_, stride = stride_, algo = algo, flipkernel = flipkernel)
+    else
+        conv2d_grad_w!(dw, x, w, dy, padding = pad_, stride = stride_, dilation = dilation, mode = flipkernel)
+    end
+end
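
The new backends.jl layer only overrides the 4-D Float32 methods, so dispatch alone keeps Float64 inputs on the existing code paths, while nnpack_supported_operation and nnpack_speed_check decide at runtime whether a given Float32 call actually goes to NNPACK. A rough sketch of how this is exercised, assuming the usual non-mutating conv/maxpool wrappers route to these in-place methods (shapes are arbitrary examples):

using NNlib

x  = rand(Float32, 28, 28, 16, 8)    # Float32, 4-D: eligible for the NNPACK path
w  = rand(Float32, 3, 3, 16, 32)
y  = conv(x, w)                      # may call nnpack_convolution_forward! when supported

xd = rand(Float64, 28, 28, 16, 8)    # Float64: never matches these methods,
yd = maxpool(xd, (2, 2))             # so the default NNlib kernels run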

src/nnpack/NNPACK.jl

Lines changed: 10 additions & 17 deletions
@@ -9,27 +9,20 @@ end
 include(depsjl_path)
 
 const nnlib_interface_path = joinpath(dirname(@__FILE__), "interface.jl")
-# const shared_threadpool = Ref(C_NULL)
+const shared_threadpool = Ref(C_NULL)
 
 @init begin
     check_deps()
+    status = nnp_initialize()
+    if status == nnp_status_unsupported_hardware
+        @warn "HARDWARE is unsupported by NNPACK so falling back to default NNlib"
+    else
+        include(nnlib_interface_path)
+    end
     try
-        global ENABLE_NNPACK = parse(UInt64, ENV["ENABLE_NNPACK"])
+        global NNPACK_CPU_THREADS = parse(UInt64, ENV["NNPACK_CPU_THREADS"])
     catch
-        global ENABLE_NNPACK = 1
-    end
-    if ENABLE_NNPACK == 1
-        status = nnp_initialize()
-        if status == nnp_status_unsupported_hardware
-            @warn "HARDWARE is unsupported by NNPACK so falling back to default NNlib"
-        else
-            include(nnlib_interface_path)
-        end
-        try
-            global NNPACK_CPU_THREADS = parse(UInt64, ENV["NNPACK_CPU_THREADS"])
-        catch
-            global NNPACK_CPU_THREADS = 4
-        end
-        global shared_threadpool = pthreadpool_create(NNPACK_CPU_THREADS)
+        global NNPACK_CPU_THREADS = 4
     end
+    shared_threadpool[] = pthreadpool_create(NNPACK_CPU_THREADS)
 end
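
The threadpool size is now read from NNPACK_CPU_THREADS at init time, defaulting to 4 when the variable is unset or unparsable. A hypothetical sketch, assuming the variable is set before the package is loaded:

ENV["NNPACK_CPU_THREADS"] = string(Sys.CPU_THREADS)   # e.g. one NNPACK thread per logical core
using NNlib                                           # @init creates the shared pthreadpool with this count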

src/nnpack/interface.jl

Lines changed: 18 additions & 41 deletions
@@ -1,52 +1,29 @@
 flipweight(w::Array{<:Any,4}) = w[end:-1:1,end:-1:1,:,:]
 
-function check_support(x, k, pad, stride, dilation = 1)
-    fallback = false
-    dilation == 1 || dilation == (1, 1) || (fallback = true)
-    pad_, stride_ = expand(Val{length(k)}, pad), expand(Val{length(k)}, stride)
-    ((size(x, 1) - k[1] + 2 * pad_[1]) % stride_[1] == 0 && (size(x, 2) - k[2] + 2 * pad_[2]) % stride_[2] == 0) || (fallback = true)
-    return pad_, stride_, fallback
-end
-
 softmax!(y::A, x::A) where A<:AbstractVecOrMat{Float32} = nnp_softmax_output(x, y)
 
-function maxpool!(y::A, x::A, k; pad = map(_->0,k), stride = k) where A<:Array{Float32, 4}
-    pad_, stride_, fallback = check_support(x, k, pad, stride)
-    if fallback
-        maxpool_cpu!(y, x, k, pad = pad_, stride = stride_)
-    else
-        nnp_max_pooling_output(x, y, k, padding = expand(Val{length(k)}, pad), stride = expand(Val{length(k)}, stride))
-    end
+nnpack_max_pooling!(y::A, x::A, k; pad = 0, stride = 1) where A<:Array{Float32, 4} =
+    nnp_max_pooling_output(y, x, k, padding = pad, stride = stride)
+
+function nnpack_convolution_forward!(y::A1, x::A1, w::A1, b::A2; pad = 0, stride = 1, algo = UInt32(0),
+                                     flipkernel = 0) where {A1<:Array{Float32, 4}, A2<:Array{Float32, 1}}
+    flipkernel == 0 && (w .= flipweight(w))
+    # Use nnp_convolution_inference if the batch size is 1.
+    # The wrapper for nnp_convolution_inference is not present, so use nnp_convolution_output for now
+    nnp_convolution_output(y, x, w, b, algo = algo, padding = pad, stride = stride)
 end
 
-function conv!(y::A1, x::A1, w::A1; pad = 0, stride = 1, dilation = 1, algo = UInt32(0), flipkernel = 0) where A1<:Array{Float32, 4}
-    pad_, stride_, fallback = check_support(x, (size(w, 1), size(w, 2)), pad, stride, dilation)
+function nnpack_convolution_backward_data!(dx::A, dy::A, x::A, w::A; pad = 0, stride = 1,
+                                           algo = UInt32(0), flipkernel = 0) where A<:Array{Float32, 4}
     flipkernel == 0 && (w .= flipweight(w))
-    if fallback
-        conv2d!(y, x, w, padding = pad, stride = stride, dilation = dilation, mode = 1)
-    else
-        nnp_convolution_output(y, x, w, zeros(Float32, size(y, 3)), algo = algo, padding = pad, stride = stride)
-    end
+    nnp_convolution_input_gradient(dx, x, dy, w, padding = pad, stride = stride, algo = algo)
 end
 
-function ∇conv_data!(dx::A, dy::A, x::A, w::A; pad = 0, stride = 1, dilation = 1, algo = UInt32(0), flipkernel = 0) where A<:Array{Float32, 4}
-    pad_, stride_, fallback = check_support(x, (size(w, 1), size(w, 2)), pad, stride, dilation)
-    if fallback
-        conv2d_grad_x!(dx, x, w, dy, padding = pad_, stride = stride_, dilation = dilation)
-    else
-        flipkernel == 0 && (w .= flipweight(w))
-        nnp_convolution_input_gradient(dx, x, dy, w, padding = pad, stride = stride, algo = algo)
-    end
+function nnpack_convolution_backward_filter!(dw::A, dy::A, x::A, w::A; pad = 0, stride = 1,
+                                             algo = UInt32(0), flipkernel = 0) where A<:Array{Float32, 4}
+    flipkernel == 0 && (w .= flipweight(w))
+    nnp_convolution_kernel_gradient(dw, x, dy, w, padding = pad, stride = stride, algo = algo)
+    flipkernel == 1 && (dw .= flipweight(dw))
+    dw
 end
 
-function ∇conv_filter!(dw::A, dy::A, x::A, w::A; pad = 0, stride = 1, dilation = 1, algo = UInt32(0), flipkernel = 0) where A<:Array{Float32, 4}
-    pad_, stride_, fallback = check_support(x, (size(w, 1), size(w, 2)), pad, stride, dilation)
-    if fallback
-        conv2d_grad_w!(dw, x, w, dy, padding = pad_, stride = stride_, dilation = dilation)
-    else
-        flipkernel == 0 && (w .= flipweight(w))
-        nnp_convolution_kernel_gradient(dw, x, dy, w, padding = pad, stride = stride, algo = algo)
-        flipkernel && (dw .= flipkernel(dw))
-        dw
-    end
-end
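
All three new wrappers pre-flip the kernel with flipweight when flipkernel == 0, i.e. when a true convolution rather than a cross-correlation is requested. A small standalone check of what flipweight does (the definition is copied from the top of this file so the snippet runs on its own):

flipweight(w::Array{<:Any,4}) = w[end:-1:1, end:-1:1, :, :]

w  = reshape(Float32[1:9;], 3, 3, 1, 1)
wf = flipweight(w)
@assert wf[1, 1, 1, 1] == w[3, 3, 1, 1]   # opposite spatial corners swap
@assert flipweight(wf) == w               # flipping twice is the identity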

src/nnpack/libnnpack.jl

Lines changed: 1 addition & 1 deletion
@@ -81,7 +81,7 @@ function nnp_max_pooling_output(batch_size, channels, input_size, input_padding,
     @check ccall((:nnp_max_pooling_output, libnnpack), nnp_status, (Csize_t, Csize_t, nnp_size, nnp_padding, nnp_size, nnp_size, Ptr{Cfloat}, Ptr{Cfloat}, pthreadpool_t), batch_size, channels, input_size, input_padding, pooling_size, pooling_stride, input, output, threadpool)
 end
 
-function nnp_max_pooling_output(x::Array{Float32,4}, y::Array{Float32,4}, kernel::Tuple; padding = 0, stride = 1, threadpool = shared_threadpool[])
+function nnp_max_pooling_output(y::Array{Float32,4}, x::Array{Float32,4}, kernel::Tuple; padding = 0, stride = 1, threadpool = shared_threadpool[])
     input_size = nnp_size(Csize_t.((size(x, 1), size(x, 2)))...)
     pooling_size = nnp_size(Csize_t.(kernel)...)
     input_padding = nnp_padding(Csize_t(padding[2]), Csize_t(padding[1]), Csize_t(padding[2]), Csize_t(padding[1]))

test/conv.jl

Lines changed: 10 additions & 19 deletions
@@ -5,26 +5,17 @@ using NNlib: conv, crosscor, ∇conv_filter, ∇conv_data, ∇maxpool, maxpool,
 x = reshape(Float32[1:16;], 4, 4, 1, 1)
 w = reshape(Float32[1:9;], 3, 3, 1, 1)
 
-# Fix these tests
-#=
-@test dropdims(conv(x, w), dims = (3,4)) == [
-    29 79 129;
-    39 89 139;
-    49 99 149;
-    59 109 159.]
+@test dropdims(conv(x, w), dims = (3,4)) == Float32.([
+    192 372;
+    237 417])
 
-@test dropdims(conv(view(x, :, :, :, :), w), dims = (3,4)) == [
-    29 79 129;
-    39 89 139;
-    49 99 149;
-    59 109 159.]
-
-@test dropdims(crosscor(x, w), dims = (3,4)) == [
-    51 101 151;
-    61 111 161;
-    71 121 171;
-    81 131 181.]
-=#
+@test dropdims(conv(view(x, :, :, :, :), w), dims = (3,4)) == Float32.([
+    192 372;
+    237 417])
+
+@test dropdims(crosscor(x, w), dims = (3,4)) == Float32.([
+    348.0 528.0;
+    393.0 573.0])
 
 @test dropdims(conv(x, w, pad=1), dims=(3,4)) ≈ Float32.([
     29 99 207 263
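
The replaced expected values are consistent with a valid (unpadded) 2×2 output for a 4×4 input and 3×3 kernel, and the top-left entries can be checked by hand; a short sketch of that arithmetic, independent of NNlib:

x = reshape(Float32[1:16;], 4, 4)
w = reshape(Float32[1:9;], 3, 3)

# top-left of the cross-correlation: window times kernel, no flip
crosscor11 = sum(x[1:3, 1:3] .* w)                        # 348
# top-left of the convolution: same window against the flipped kernel
conv11     = sum(x[1:3, 1:3] .* w[end:-1:1, end:-1:1])    # 192

@assert crosscor11 == 348f0 && conv11 == 192f0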
