
Commit a30c7a7
Author: Avik Pal
Expose usage of NNPACK conv and maxpool operations
Parent: 5b35148
5 files changed: +60 additions, −11 deletions

src/NNlib.jl

Lines changed: 9 additions & 6 deletions

@@ -3,8 +3,17 @@ using Requires, TimerOutputs
 
 const to = TimerOutput()
 
+
 # Include APIs
 include("dim_helpers.jl")
+
+# NNPACK support
+if Sys.islinux()
+    include("nnpack/NNPACK.jl")
+else
+    is_nnpack_available() = false
+end
+
 include("activation.jl")
 include("softmax.jl")
 include("gemm.jl")

@@ -24,10 +33,4 @@ include("impl/depthwiseconv_im2col.jl")
 # Direct implementations of pooling
 include("impl/pooling_direct.jl")
 
-if Sys.islinux()
-    include("nnpack/NNPACK.jl")
-else
-    is_nnpack_available() = false
-end
-
 end # module NNlib
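This moves the NNPACK include ahead of the files that use it, so `is_nnpack_available()` is defined before `conv.jl` and `pooling.jl` are loaded (both of which branch on it at include time): on Linux it comes from `nnpack/NNPACK.jl`, elsewhere it is a stub returning `false`. A minimal sketch of branching on it from downstream code, assuming the function is reachable as `NNlib.is_nnpack_available` (nothing beyond what the diff defines):

using NNlib

if NNlib.is_nnpack_available()
    # nnpack/NNPACK.jl was loaded: conv/maxpool may dispatch to NNPACK
    @info "NNPACK backend active"
else
    # non-Linux platforms: only the pure-Julia backends are defined
    @info "Using conv_im2col / conv_direct fallbacks"
end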

src/conv.jl

Lines changed: 12 additions & 0 deletions

@@ -151,3 +151,15 @@ for backend in (Symbol(), :_direct, :_im2col)
         end
     end
 end
+
+
+# Use NNPACK if it is available and the operation is supported
+if is_nnpack_available()
+    function conv(x::Array{xT, 4}, w::Array{wT, 4},
+                  cdims::DenseConvDims{2, K, C_in, C_out, S, P, (1, 1), F};
+                  kwargs...) where {xT, wT, K, C_in, C_out, S, P, F}
+        func = check_supported_operation(x, cdims) ? conv_nnpack :
+               xT == wT ? conv_im2col : conv_direct
+        return func(x, w, cdims; kwargs...)
+    end
+end
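The new method only claims 2-d convolutions with dilation (1, 1); everything else keeps falling through to the generic paths. A hedged usage sketch — the array sizes and the `DenseConvDims` constructor call here are illustrative, not part of the commit:

using NNlib

x = rand(Float32, 32, 32, 3, 1)    # W × H × C_in × batch
w = rand(Float32, 3, 3, 3, 16)     # 3×3 kernel, 3 in-channels, 16 out-channels
cdims = DenseConvDims(x, w)        # stride 1, no padding, dilation (1, 1)

# check_supported_operation passes here (stride 1 divides any extent), so
# func == conv_nnpack; unsupported shapes fall back to conv_im2col, or to
# conv_direct when the input and kernel eltypes differ.
y = conv(x, w, cdims)
size(y)                            # (30, 30, 16, 1)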

src/nnpack/interface.jl

Lines changed: 28 additions & 5 deletions

@@ -10,7 +10,7 @@ for (front_name, backend) in (
     @timeit_debug to function $(Symbol("$(front_name)$(backend)!"))(
                     out::Array{T1,4}, in1::Array{T2,4}, in2::Array{T3,4},
                     cdims::ConvDims; kwargs...) where {T1, T2, T3}
-        @warn "Automatically converting $(size(in1)) input tensor to Float32" maxlog=1
+        @warn "Automatically converting input tensor to Float32. This will have performance implications" maxlog=1
         # Output must be of the same type as in the function signature
         T1.($(Symbol("$(front_name)$(backend)!"))(Float32.(out), Float32.(in1),
             Float32.(in2), cdims; kwargs...))

@@ -20,26 +20,26 @@ end
 
 
 function conv_nnpack(x::Array{T1, 4}, w::Array{T2, 4}, cdims::ConvDims; kwargs...) where {T1, T2}
-    y = similar(x, output_size(cdims), channels_out(cdims), size(x, 4))
+    y = similar(x, output_size(cdims)..., channels_out(cdims), size(x, 4))
     return conv_nnpack!(y, x, w, cdims; kwargs...)
 end
 
 
 function ∇conv_data(dy::Array{T1, 4}, w::Array{T2, 4}, cdims::ConvDims; kwargs...) where {T1, T2}
-    dx = similar(dy, input_size(cdims), channels_in(cdims), size(dy, 4))
+    dx = similar(dy, input_size(cdims)..., channels_in(cdims), size(dy, 4))
     return ∇conv_data!(dx, dy, w, cdims; kwargs...)
 end
 
 
 function ∇conv_filter(x::Array{T1, 4}, dy::Array{T2, 4}, cdims::ConvDims; kwargs...) where {T1, T2}
-    dw = similar(x, kernel_size(cdims), channels_in(cdims), channels_out(cdims))
+    dw = similar(x, kernel_size(cdims)..., channels_in(cdims), channels_out(cdims))
     return ∇conv_filter!(dw, x, dy, cdims; kwargs...)
 end
 
 
 function maxpool_nnpack!(y::Array{T1, 4}, x::Array{T2, 4}, pdims::PoolDims;
                          kwargs...) where {T1, T2}
-    @warn "Automatically converting $(size(x)) input tensor to Float32" maxlog=1
+    @warn "Automatically converting input tensor to Float32. This will have performance implications" maxlog=1
     # We want the output to be of the same type as desired
     T1.(maxpool_nnpack!(Float32.(y), Float32.(x), pdims; kwargs...))
 end

@@ -49,3 +49,26 @@ function maxpool_nnpack(x::Array{T, 4}, pdims::PoolDims; kwargs...) where {T}
     y = similar(x, output_size(pdims)..., channels_out(pdims), size(x, 4))
     return maxpool_nnpack!(y, x, pdims; kwargs...)
 end
+
+
+"""
+    check_supported_operation(x::Array, cdims::DenseConvDims)
+
+Returns `true` if NNPACK supports the convolution operation for the given input.
+"""
+function check_supported_operation(x::Array{T, 4}, cdims::DenseConvDims{2, K, C_in,
+                                   C_out, S, P, (1, 1), F}) where {T, K, C_in, C_out, S, P, F}
+    val = size(x)[1:2] .+ (P[1] + P[2], P[3] + P[4]) .- K
+    return val .% S == (0, 0) ? true : false
+end
+
+
+"""
+    check_supported_operation(x::Array, pdims::PoolDims)
+
+Returns `true` if NNPACK supports the pooling operation for the given input.
+"""
+function check_supported_operation(x::Array{T, 4}, pdims::PoolDims{2, K, S, P, (1, 1)}) where {T, K, S, P}
+    val = size(x)[1:2] .+ (P[1] + P[2], P[3] + P[4]) .- K
+    return val .% S == (0, 0) ? true : false
+end
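The support check is a divisibility condition: the padded input extent minus the kernel size must be an exact multiple of the stride in both spatial dimensions, i.e. the windows must tile the input with nothing left over. Worked through standalone with illustrative numbers (not from the commit):

# 32×32 input, 3×3 kernel, symmetric padding of 1, stride 2:
spatial = (32, 32)
P = (1, 1, 1, 1)                # (left, right, top, bottom) padding
K = (3, 3)
S = (2, 2)

val = spatial .+ (P[1] + P[2], P[3] + P[4]) .- K   # (31, 31)
val .% S == (0, 0)              # false: 31 % 2 != 0, so NNPACK is skipped

With stride (1, 1) the remainder is always (0, 0) and the NNPACK path is taken.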

src/nnpack/performance.jl

Lines changed: 2 additions & 0 deletions

@@ -8,6 +8,8 @@ function select_threadpool(pdims::PoolDims, batch_size::Int)
         return shared_threadpool_dict[4][]
     elseif batch_size >= 16 && inp_size >= 64
         return shared_threadpool_dict[4][]
+    elseif inp_size <= 32
+        return C_NULL
     elseif inp_size >= 128
         return shared_threadpool_dict[4][]
     elseif inp_size * batch_size >= 256
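Passing `C_NULL` as the pthreadpool handle makes NNPACK run on the calling thread only, which the new branch now does for small spatial sizes where thread-dispatch overhead would outweigh any parallel speedup. A self-contained sketch of the heuristic with symbols standing in for the actual pool handles (the final fallback branch is an assumption, since the rest of the function lies outside this hunk):

function threadpool_choice(inp_size::Int, batch_size::Int)
    if batch_size >= 16 && inp_size >= 64
        return :pool            # enough total work to amortize threading
    elseif inp_size <= 32
        return :single_thread   # the new branch: maps to C_NULL
    elseif inp_size >= 128
        return :pool
    elseif inp_size * batch_size >= 256
        return :pool
    else
        return :single_thread   # assumed fallback
    end
end

threadpool_choice(28, 8)    # => :single_thread
threadpool_choice(64, 32)   # => :pool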

src/pooling.jl

Lines changed: 9 additions & 0 deletions

@@ -127,3 +127,12 @@ for backend in (Symbol(), :_direct, :_im2col)
         end
     end
 end
+
+
+# Use NNPACK if it is available and the operation is supported
+if is_nnpack_available()
+    function maxpool(x::Array{T, 4}, pdims::PoolDims{2, K, S, P, (1, 1)}; kwargs...) where {T, K, S, P}
+        func = check_supported_operation(x, pdims) ? maxpool_nnpack : maxpool_im2col
+        return func(x, pdims; kwargs...)
+    end
+end
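Same pattern as conv: the override fires only for 2-d pooling with dilation (1, 1) and falls back to maxpool_im2col when the stride check fails. A usage sketch — the sizes and the `PoolDims` call are illustrative assumptions:

using NNlib

x = rand(Float32, 32, 32, 3, 1)   # W × H × C × batch
pdims = PoolDims(x, 2)            # 2×2 window; stride defaults to the window

# (32 + 0 - 2) % 2 == 0 in both dimensions, so check_supported_operation
# passes and maxpool_nnpack handles the call.
y = maxpool(x, pdims)
size(y)                           # (16, 16, 3, 1)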
