Skip to content

Commit 8f8081a

Browse files
author
Avik Pal
committed
Remove common code
1 parent da3df3c commit 8f8081a

File tree

3 files changed

+38
-50
lines changed

3 files changed

+38
-50
lines changed

src/nnpack/NNPACK.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@ if !isfile(depsjl_path)
88
end
99
include(depsjl_path)
1010

11-
const nnlib_interface_path = joinpath(dirname(@__FILE__), "nnlib.jl")
11+
const nnlib_interface_path = joinpath(dirname(@__FILE__), "interface.jl")
12+
global shared_threadpool = Ref(C_NULL)
13+
1214
@init begin
1315
check_deps()
1416
status = nnp_initialize()
@@ -22,5 +24,5 @@ const nnlib_interface_path = joinpath(dirname(@__FILE__), "nnlib.jl")
2224
catch
2325
global NNPACK_CPU_THREADS = 4
2426
end
25-
global shared_threadpool = Ref(pthreadpool_create(NNPACK_CPU_THREADS), 1)
27+
global shared_threadpool = Ref(pthreadpool_create(NNPACK_CPU_THREADS))
2628
end
Lines changed: 22 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,47 +1,51 @@
1+
function check_support(x, k, pad, stride, dilation = 0)
2+
dilation == 1 || dilation == (1, 1) || error("NNPACK does not support dilation > 1")
3+
pad_, stride_ = expand(Val{length(k)}, pad), expand(Val{length(k)}, stride)
4+
((size(x, 1) - k[1] + 2 * pad_[1]) % stride_[1] == 0 && (size(x, 2) - k[2] + 2 * pad_[2]) % stride_[2] == 0) || error("Choose the stride, pad and kernel size properly")
5+
return pad_, stride_
6+
end
7+
18
#NOTE: Commenting out the activation functions until sure what to do
29

3-
# relu(x::AA1) = nnp_relu_output(x, inplace ? x : similar(x), threadpool = shared_threadpool)
10+
# relu(x::AA1) = nnp_relu_output(x, inplace ? x : similar(x), threadpool = shared_threadpool[])
411

512
# leakyrelu(x::AA1, a = oftype(x/1, 0.01)) =
6-
# nnp_relu_output(x, inplace ? x : similar(x), negative_slope = a, threadpool = shared_threadpool)
13+
# nnp_relu_output(x, inplace ? x : similar(x), negative_slope = a, threadpool = shared_threadpool[])
714

815
softmax!(x::A) where A<:AbstractVecOrMat{Float64} = softmax!(Float32.(x))
916

1017
softmax!(x::A) where A<:AbstractVecOrMat{Float32} =
11-
nnp_softmax_output(x, x, threadpool = shared_threadpool)
18+
nnp_softmax_output(x, x, threadpool = shared_threadpool[])
1219

1320
softmax!(y::A, x::A) where A<:AbstractVecOrMat{Float64} = softmax!(Float32.(y), Float32.(x))
1421

1522
softmax!(y::A, x::A) where A<:AbstractVecOrMat{Float32} =
16-
nnp_softmax_output(x, y, threadpool = shared_threadpool)
23+
nnp_softmax_output(x, y, threadpool = shared_threadpool[])
1724

1825
softmax(x::A) where A<:AbstractVecOrMat{Float64} = softmax(Float32.(x))
1926

2027
softmax(x::A) where A<:AbstractVecOrMat{Float32} =
21-
nnp_softmax_output(x, similar(x), threadpool = shared_threadpool)
28+
nnp_softmax_output(x, similar(x), threadpool = shared_threadpool[])
2229

2330
maxpool(x::A, k; pad = map(_->0,k), stride = k) where A<:AbstractArray{Float64, 4} =
2431
maxpool(Float32.(x), k, pad = pad, stride = stride)
2532

2633
function maxpool(x::A, k; pad = map(_->0,k), stride = k) where A<:AbstractArray{Float32, 4}
27-
pad_, stride_ = expand(Val{length(k)}, pad), expand(Val{length(k)}, stride)
28-
((size(x, 1) - k[1] + 2 * pad_[1]) % stride_[1] == 0 && (size(x, 2) - k[2] + 2 * pad_[2]) % stride_[2] == 0) || error("Choose the stride, pad and kernel size properly")
34+
pad_, stride_ = check_support(x, k, pad, stride)
2935
maxpool!(similar(x, pdims(size(x), k, pad_, stride_)), x, k, pad = pad_, stride = stride_)
3036
end
3137

3238
maxpool!(y::A, x::A, k; pad = map(_->0,k), stride = k) where A<:AbstractArray{Float64, 4} =
3339
maxpool!(Float32.(y), Float32.(x), k, pad = pad, stride = stride)
3440

3541
maxpool!(y::A, x::A, k; pad = map(_->0,k), stride = k) where A<:AbstractArray{Float32, 4} =
36-
nnp_max_pooling_output(x, y, k, padding = expand(Val{length(k)}, pad), stride = expand(Val{length(k)}, stride), threadpool = shared_threadpool)
42+
nnp_max_pooling_output(x, y, k, padding = expand(Val{length(k)}, pad), stride = expand(Val{length(k)}, stride), threadpool = shared_threadpool[])
3743

3844
conv(x::A, w::A; pad = 0, stride = 1, dilation = 1, algo = UInt32(0)) where A<:AbstractArray{Float64, 4} =
3945
conv(Float32.(x), Float32.(w), pad = pad, stride = stride, dilation = dilation, algo = algo)
4046

4147
function conv(x::A, w::A; pad = 0, stride = 1, dilation = 1, algo = UInt32(0)) where A<:AbstractArray{Float32, 4}
42-
dilation == 1 || dilation == (1, 1) || error("NNPACK does not support dilation > 1")
43-
pad_, stride_ = padtuple(x, pad), padtuple(x, stride)
44-
((size(x, 1) - size(w, 1) + 2 * pad_[1]) % stride_[1] == 0 && (size(x, 2) - size(w, 2) + 2 * pad_[2]) % stride_[2] == 0) || error("Choose the stride, pad and kernel size properly")
48+
pad_, stride_ = check_support(x, k, pad, stride)
4549
y = similar(x, cdims(size(x), dilation_dims(w, dilation), pad_, stride_))
4650
b = zeros(Float32, size(y, 3))
4751
conv!(y, x, w, b, pad = pad_, stride = stride_, dilation = dilation, algo = UInt32(algo))
@@ -51,19 +55,15 @@ conv(x::A1, w::A1, b::A2; pad = 0, stride = 1, dilation = 1, algo = UInt32(0)) w
5155
conv(Float32.(x), Float32.(w), Float32.(b), pad = pad, stride = stride, dilation = dilation, algo = algo)
5256

5357
function conv(x::A1, w::A1, b::A2; pad = 0, stride = 1, dilation = 1, algo = UInt32(0)) where {A1<:AbstractArray{Float32, 4}, A2<:AbstractArray{Float32, 1}}
54-
dilation == 1 || dilation == (1, 1) || error("NNPACK does not support dilation > 1")
55-
pad_, stride_ = padtuple(x, pad), padtuple(x, stride)
56-
((size(x, 1) - size(w, 1) + 2 * pad_[1]) % stride_[1] == 0 && (size(x, 2) - size(w, 2) + 2 * pad_[2]) % stride_[2] == 0) || error("Choose the stride, pad and kernel size properly")
58+
pad_, stride_ = check_support(x, k, pad, stride)
5759
conv!(similar(x, cdims(size(x), dilation_dims(w, dilation), pad_, stride_)), x, w, b, pad = pad_, stride = stride_, dilation = dilation, algo = UInt32(algo))
5860
end
5961

6062
crosscor(x::A1, w::A1, b::A2; pad = 0, stride = 1, dilation = 1, algo = UInt32(0)) where {A1<:AbstractArray{Float64, 4}, A2<:AbstractArray{Float64, 1}} =
6163
crosscor(Float32.(x), Float32.(w), Float32.(b), pad = pad, stride = stride, dilation = dilation, algo = algo)
6264

6365
function crosscor(x::A1, w::A1, b::A2; pad = 0, stride = 1, dilation = 1, algo = UInt32(0)) where {A1<:AbstractArray{Float32, 4}, A2<:AbstractArray{Float32, 1}}
64-
dilation == 1 || dilation == (1, 1) || error("NNPACK does not support dilation > 1")
65-
pad_, stride_ = padtuple(x, pad), padtuple(x, stride)
66-
((size(x, 1) - size(w, 1) + 2 * pad_[1]) % stride_[1] == 0 && (size(x, 2) - size(w, 2) + 2 * pad_[2]) % stride_[2] == 0) || error("Choose the stride, pad and kernel size properly")
66+
pad_, stride_ = check_support(x, k, pad, stride)
6767
conv!(similar(x, cdims(size(x), dilation_dims(w, dilation), pad_, stride_)), x, w, b, pad = pad_, stride = stride_, dilation = dilation, algo = UInt32(algo), flipkernel = 1)
6868
end
6969

@@ -72,7 +72,7 @@ conv!(y::A1, x::A1, w::A1, b::A2; pad = 0, stride = 1, dilation = 1, algo = UInt
7272

7373
function conv!(y::A1, x::A1, w::A1, b::A2; pad = 0, stride = 1, dilation = 1, algo = UInt32(0), flipkernel = 0) where {A1<:AbstractArray{Float32, 4}, A2<:AbstractArray{Float32, 1}}
7474
flipkernel == 0 && (w = reverse(reverse(w, dims=1), dims=2))
75-
nnp_convolution_output(y, x, w, b, algo = algo, padding = pad, stride = stride, threadpool = shared_threadpool)
75+
nnp_convolution_output(y, x, w, b, algo = algo, padding = pad, stride = stride, threadpool = shared_threadpool[])
7676
end
7777

7878
crosscor!(y::A1, x::A1, w::A1, b::A2; pad = 0, stride = 1, dilation = 1, algo = UInt32(0)) where {A1<:AbstractArray{Float64, 4}, A2<:AbstractArray{Float64, 1}} =
@@ -85,9 +85,7 @@ crosscor!(y::A1, x::A1, w::A1, b::A2; pad = 0, stride = 1, dilation = 1, algo =
8585
∇conv_data(Float32.(dy), Float32.(x), Float32.(w), pad = pad, stride = stride, dilation = dilation, algo = algo)
8686

8787
function ∇conv_data(dy::A, x::A, w::A; pad = 0, stride = 1, dilation = 1, algo = UInt32(0)) where A<:AbstractArray{Float32, 4}
88-
dilation == 1 || dilation == (1, 1) || error("NNPACK does not support dilation > 1")
89-
pad_, stride_ = padtuple(x, pad), padtuple(x, stride)
90-
((size(x, 1) - size(w, 1) + 2 * pad_[1]) % stride_[1] == 0 && (size(x, 2) - size(w, 2) + 2 * pad_[2]) % stride_[2] == 0) || error("Choose the stride, pad and kernel size properly")
88+
pad_, stride_ = check_support(x, k, pad, stride)
9189
∇conv_data!(zeros(Float32, size(x)), dy, x, w; pad = pad_, stride = stride_, dilation = dilation, algo = UInt32(algo))
9290
end
9391

@@ -96,16 +94,14 @@ end
9694

9795
function ∇conv_data!(dx::A, dy::A, x::A, w::A; pad = 0, stride = 1, dilation = 1, algo = UInt32(0), flipkernel = 0) where A<:AbstractArray{Float32, 4}
9896
flipkernel == 0 && (w = reverse(reverse(w, dims=1), dims=2))
99-
nnp_convolution_input_gradient(dx, x, dy, w, padding = pad, stride = stride, algo = algo, threadpool = shared_threadpool)
97+
nnp_convolution_input_gradient(dx, x, dy, w, padding = pad, stride = stride, algo = algo, threadpool = shared_threadpool[])
10098
end
10199

102100
∇conv_filter(dy::A, x::A, w::A; pad = 0, stride = 1, dilation = 1, algo = UInt32(0)) where A<:AbstractArray{Float64, 4} =
103101
∇conv_filter(Float32.(dy), Float32.(x), Float32.(w), pad = pad, stride = stride, dilation = dilation, algo = algo)
104102

105103
function ∇conv_filter(dy::A, x::A, w::A; pad = 0, stride = 1, dilation = 1, algo = UInt32(0)) where A<:AbstractArray{Float32, 4}
106-
dilation == 1 || dilation == (1, 1) || error("NNPACK does not support dilation > 1")
107-
pad_, stride_ = padtuple(x, pad), padtuple(x, stride)
108-
((size(x, 1) - size(w, 1) + 2 * pad_[1]) % stride_[1] == 0 && (size(x, 2) - size(w, 2) + 2 * pad_[2]) % stride_[2] == 0) || error("Choose the stride, pad and kernel size properly")
104+
pad_, stride_ = check_support(x, k, pad, stride)
109105
∇conv_filter!(zeros(Float32, size(w)), dy, x, w; pad = pad_, stride = stride_, dilation = dilation, algo = UInt32(algo))
110106
end
111107

@@ -114,6 +110,6 @@ end
114110

115111
function ∇conv_filter!(dw::A, dy::A, x::A, w::A; pad = 0, stride = 1, dilation = 1, algo = UInt32(0), flipkernel = 0) where A<:AbstractArray{Float32, 4}
116112
flipkernel == 0 && (w = reverse(reverse(w, dims=1), dims=2))
117-
dw .= nnp_convolution_kernel_gradient(dw, x, dy, w, padding = pad, stride = stride, algo = algo, threadpool = shared_threadpool)
113+
dw .= nnp_convolution_kernel_gradient(dw, x, dy, w, padding = pad, stride = stride, algo = algo, threadpool = shared_threadpool[])
118114
flipkernel == 0 ? reverse(reverse(dw, dims=1), dims=2) : dw
119115
end

0 commit comments

Comments
 (0)