Skip to content

Commit 838e88a

Browse files
author
Avik Pal
committed
Remove AbstractArray
1 parent f283696 commit 838e88a

File tree

1 file changed

+17
-58
lines changed

1 file changed

+17
-58
lines changed

src/nnpack/interface.jl

Lines changed: 17 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
flipweight(w::AbstractArray{<:Any,4}) = w[end:-1:1,end:-1:1,:,:]
1+
flipweight(w::Array{<:Any,4}) = w[end:-1:1,end:-1:1,:,:]
22

33
function check_support(x, k, pad, stride, dilation = 1)
44
fallback = false
@@ -11,84 +11,43 @@ end
1111
softmax!(y::A, x::A) where A<:AbstractVecOrMat{Float32} =
1212
nnp_softmax_output(x, y)
1313

14-
function maxpool(x::A, k; pad = map(_->0,k), stride = k) where A<:AbstractArray{Float32, 4}
14+
function maxpool!(y::A, x::A, k; pad = map(_->0,k), stride = k) where A<:Array{Float32, 4}
1515
pad_, stride_, fallback = check_support(x, k, pad, stride)
1616
if fallback
17-
maxpool_cpu!(similar(x, pdims(size(x), k, pad_, stride_)), x, k, pad = pad_, stride = stride_)
17+
maxpool_cpu!(y, x, k, pad = pad_, stride = stride_)
1818
else
19-
maxpool!(similar(x, pdims(size(x), k, pad_, stride_)), x, k, pad = pad_, stride = stride_)
19+
nnp_max_pooling_output(x, y, k, padding = expand(Val{length(k)}, pad), stride = expand(Val{length(k)}, stride))
2020
end
2121
end
2222

23-
maxpool!(y::A, x::A, k; pad = map(_->0,k), stride = k) where A<:AbstractArray{Float32, 4} =
24-
nnp_max_pooling_output(x, y, k, padding = expand(Val{length(k)}, pad), stride = expand(Val{length(k)}, stride))
25-
26-
function conv(x::A, w::A; pad = 0, stride = 1, dilation = 1, algo = UInt32(0)) where A<:AbstractArray{Float32, 4}
23+
function conv!(y::A1, x::A1, w::A1; pad = 0, stride = 1, dilation = 1, algo = UInt32(0), flipkernel = 0) where A1<:Array{Float32, 4}
2724
pad_, stride_, fallback = check_support(x, (size(w, 1), size(w, 2)), pad, stride, dilation)
28-
y = similar(x, cdims(size(x), dilation_dims(w, dilation), pad_, stride_))
29-
if fallback
30-
conv2d!(y, x, w, padding = pad, stride = stride, dilation = dilation)
31-
else
32-
conv!(y, x, w, zeros(Float32, size(y, 3)), pad = pad_, stride = stride_, dilation = dilation, algo = UInt32(algo))
33-
end
34-
end
35-
36-
function conv(x::A1, w::A1, b::A2; pad = 0, stride = 1, dilation = 1, algo = UInt32(0)) where {A1<:AbstractArray{Float32, 4}, A2<:AbstractArray{Float32, 1}}
37-
pad_, stride_, fallback = check_support(x, (size(w, 1), size(w, 2)), pad, stride, dilation)
38-
y = similar(x, cdims(size(x), dilation_dims(w, dilation), pad_, stride_))
39-
if fallback
40-
conv2d!(y, x, w, padding = pad, stride = stride, dilation = dilation)
41-
else
42-
conv!(y, x, w, b, pad = pad_, stride = stride_, dilation = dilation, algo = UInt32(algo))
43-
end
44-
end
45-
46-
function crosscor(x::A1, w::A1, b::A2; pad = 0, stride = 1, dilation = 1, algo = UInt32(0)) where {A1<:AbstractArray{Float32, 4}, A2<:AbstractArray{Float32, 1}}
47-
pad_, stride_, fallback = check_support(x, (size(w, 1), size(w, 2)), pad, stride, dilation)
48-
y = similar(x, cdims(size(x), dilation_dims(w, dilation), pad_, stride_))
25+
flipkernel == 0 && (w .= flipweight(w))
4926
if fallback
5027
conv2d!(y, x, w, padding = pad, stride = stride, dilation = dilation, mode = 1)
5128
else
52-
conv!(y, x, w, b, pad = pad_, stride = stride_, dilation = dilation, algo = UInt32(algo), flipkernel = 1)
53-
end
54-
end
55-
56-
function conv!(y::A1, x::A1, w::A1, b::A2; pad = 0, stride = 1, dilation = 1, algo = UInt32(0), flipkernel = 0) where {A1<:AbstractArray{Float32, 4}, A2<:AbstractArray{Float32, 1}}
57-
if flipkernel == 0
58-
w = flipweight(w)
29+
nnp_convolution_output(y, x, w, zeros(Float32, size(y, 3)), algo = algo, padding = pad, stride = stride)
5930
end
60-
nnp_convolution_output(y, x, w, b, algo = algo, padding = pad, stride = stride)
6131
end
6232

63-
crosscor!(y::A1, x::A1, w::A1, b::A2; pad = 0, stride = 1, dilation = 1, algo = UInt32(0)) where {A1<:AbstractArray{Float32, 4}, A2<:AbstractArray{Float32, 1}} =
64-
conv!(y, x, w, b, pad = pad, stride = stride, dilation = dilation, algo = algo, flipkernel = 1)
65-
66-
function ∇conv_data(dy::A, x::A, w::A; pad = 0, stride = 1, dilation = 1, algo = UInt32(0)) where A<:AbstractArray{Float32, 4}
33+
function ∇conv_data!(dx::A, dy::A, x::A, w::A; pad = 0, stride = 1, dilation = 1, algo = UInt32(0), flipkernel = 0) where A<:Array{Float32, 4}
6734
pad_, stride_, fallback = check_support(x, (size(w, 1), size(w, 2)), pad, stride, dilation)
6835
if fallback
69-
conv2d_grad_x!(zeros(Float32, size(x)), x, w, dy, padding = pad_, stride = stride_, dilation = dilation)
36+
conv2d_grad_x!(dx, x, w, dy, padding = pad_, stride = stride_, dilation = dilation)
7037
else
71-
∇conv_data!(zeros(Float32, size(x)), dy, x, w; pad = pad_, stride = stride_, dilation = dilation, algo = UInt32(algo))
38+
flipkernel == 0 && (w .= flipweight(w))
39+
nnp_convolution_input_gradient(dx, x, dy, w, padding = pad, stride = stride, algo = algo)
7240
end
7341
end
7442

75-
function ∇conv_data!(dx::A, dy::A, x::A, w::A; pad = 0, stride = 1, dilation = 1, algo = UInt32(0), flipkernel = 0) where A<:AbstractArray{Float32, 4}
76-
flipkernel == 0 && (w = flipweight(w))
77-
nnp_convolution_input_gradient(dx, x, dy, w, padding = pad, stride = stride, algo = algo)
78-
end
79-
80-
function ∇conv_filter(dy::A, x::A, w::A; pad = 0, stride = 1, dilation = 1, algo = UInt32(0)) where A<:AbstractArray{Float32, 4}
43+
function ∇conv_filter!(dw::A, dy::A, x::A, w::A; pad = 0, stride = 1, dilation = 1, algo = UInt32(0), flipkernel = 0) where A<:Array{Float32, 4}
8144
pad_, stride_, fallback = check_support(x, (size(w, 1), size(w, 2)), pad, stride, dilation)
8245
if fallback
83-
conv2d_grad_w!(zeros(Float32, size(w)), x, w, dy, padding = pad_, stride = stride_, dilation = dilation)
46+
conv2d_grad_w!(dw, x, w, dy, padding = pad_, stride = stride_, dilation = dilation)
8447
else
85-
∇conv_filter!(zeros(Float32, size(w)), dy, x, w; pad = pad_, stride = stride_, dilation = dilation, algo = UInt32(algo))
48+
flipkernel == 0 && (w .= flipweight(w))
49+
nnp_convolution_kernel_gradient(dw, x, dy, w, padding = pad, stride = stride, algo = algo)
50+
flipkernel && (dw .= flipkernel(dw))
51+
dw
8652
end
8753
end
88-
89-
function ∇conv_filter!(dw::A, dy::A, x::A, w::A; pad = 0, stride = 1, dilation = 1, algo = UInt32(0), flipkernel = 0) where A<:AbstractArray{Float32, 4}
90-
flipkernel == 0 && (w = flipweight(w))
91-
nnp_convolution_kernel_gradient(dw, x, dy, w, padding = pad, stride = stride, algo = algo)
92-
flipkernel && (dw .= flipkernel(dw))
93-
return dw
94-
end

0 commit comments

Comments
 (0)