|
1 |
"""
    flipweight(w::Array{<:Any,4})

Return a new array holding `w` with its first and second dimensions reversed;
the third and fourth dimensions are left in their original order. `w` itself is
not modified.
"""
flipweight(w::Array{<:Any,4}) = w[end:-1:1, end:-1:1, :, :]
2 | 2 |
|
3 | 3 | function check_support(x, k, pad, stride, dilation = 1)
|
4 | 4 | fallback = false
|
|
"""
    softmax!(y, x)

In-place softmax for `Float32` vectors or matrices of one common type.
Forwards to `nnp_softmax_output(x, y)` (input first, output second) and returns
whatever that call returns.
"""
function softmax!(y::A, x::A) where A<:AbstractVecOrMat{Float32}
    return nnp_softmax_output(x, y)
end
|
13 | 13 |
|
14 |
"""
    maxpool!(y, x, k; pad = map(_->0,k), stride = k)

Max-pool `x` with window `k`, writing into the preallocated `y`
(both `Array{Float32,4}` of one common type).

`check_support` reports whether the requested configuration has to take the
fallback route: if so the work is done by `maxpool_cpu!` with the padding and
stride it returned, otherwise by `nnp_max_pooling_output` with `pad` and
`stride` expanded to `length(k)` entries.
"""
function maxpool!(y::A, x::A, k; pad = map(_->0,k), stride = k) where A<:Array{Float32, 4}
    pad_, stride_, fallback = check_support(x, k, pad, stride)
    if fallback
        return maxpool_cpu!(y, x, k, pad = pad_, stride = stride_)
    else
        return nnp_max_pooling_output(x, y, k,
                                      padding = expand(Val{length(k)}, pad),
                                      stride = expand(Val{length(k)}, stride))
    end
end
25 |
| - |
26 |
"""
    conv!(y, x, w; pad = 0, stride = 1, dilation = 1, algo = UInt32(0), flipkernel = 0)

Apply kernel `w` to `x` and write the result into the preallocated `y`
(all `Array{Float32,4}` of one common type).

# Keywords
- `pad`, `stride`, `dilation`: geometry of the operation; they are first passed
  through `check_support`, which also reports whether the fallback is needed.
- `algo`: forwarded unchanged to `nnp_convolution_output`.
- `flipkernel`: with `0` the kernel is reversed along its first two dimensions
  before use; with any other value `w` is used as given.

On the fallback route the work is done by `conv2d!` (called with `mode = 1`);
otherwise by `nnp_convolution_output` with an all-zero bias of length
`size(y, 3)`.
"""
function conv!(y::A1, x::A1, w::A1; pad = 0, stride = 1, dilation = 1,
               algo = UInt32(0), flipkernel = 0) where A1<:Array{Float32, 4}
    pad_, stride_, fallback =
        check_support(x, (size(w, 1), size(w, 2)), pad, stride, dilation)
    # Flip into a fresh array so the caller's `w` is never modified; an in-place
    # `w .= flipweight(w)` would reverse the stored weights on every call.
    kernel = flipkernel == 0 ? flipweight(w) : w
    if fallback
        return conv2d!(y, x, kernel, padding = pad, stride = stride,
                       dilation = dilation, mode = 1)
    else
        # Hand over the padding/stride that `check_support` validated.
        # NOTE(review): assumes `pad_`/`stride_` are the normalised forms of
        # `pad`/`stride` — confirm against `check_support`.
        return nnp_convolution_output(y, x, kernel, zeros(Float32, size(y, 3)),
                                      algo = algo, padding = pad_, stride = stride_)
    end
end
65 |
| - |
66 |
"""
    ∇conv_data!(dx, dy, x, w; pad = 0, stride = 1, dilation = 1, algo = UInt32(0), flipkernel = 0)

Compute the gradient with respect to the input `x` from the output gradient
`dy` and kernel `w`, writing it into the preallocated `dx`
(all `Array{Float32,4}` of one common type).

`check_support` selects the route: `conv2d_grad_x!` on the fallback route,
`nnp_convolution_input_gradient` otherwise. On the latter route, with
`flipkernel == 0`, the kernel is reversed along its first two dimensions before
use; `algo` is forwarded unchanged.
"""
function ∇conv_data!(dx::A, dy::A, x::A, w::A; pad = 0, stride = 1, dilation = 1,
                     algo = UInt32(0), flipkernel = 0) where A<:Array{Float32, 4}
    pad_, stride_, fallback =
        check_support(x, (size(w, 1), size(w, 2)), pad, stride, dilation)
    if fallback
        # NOTE(review): `flipkernel` is not consulted on this route — confirm
        # that `conv2d_grad_x!`'s default mode matches what callers expect.
        return conv2d_grad_x!(dx, x, w, dy, padding = pad_, stride = stride_,
                              dilation = dilation)
    else
        # Flip into a fresh array so the caller's `w` is never modified.
        kernel = flipkernel == 0 ? flipweight(w) : w
        # Hand over the padding/stride that `check_support` validated.
        return nnp_convolution_input_gradient(dx, x, dy, kernel, padding = pad_,
                                              stride = stride_, algo = algo)
    end
end
79 |
| - |
80 |
"""
    ∇conv_filter!(dw, dy, x, w; pad = 0, stride = 1, dilation = 1, algo = UInt32(0), flipkernel = 0)

Compute the gradient with respect to the kernel `w` from the input `x` and the
output gradient `dy`, writing it into the preallocated `dw`
(all `Array{Float32,4}` of one common type).

`check_support` selects the route: `conv2d_grad_w!` on the fallback route,
`nnp_convolution_kernel_gradient` otherwise. On the latter route, with
`flipkernel == 0`, the kernel is reversed along its first two dimensions before
the call and the resulting `dw` is reversed the same way afterwards, so the
gradient lines up with the unflipped `w`; `dw` is returned.
"""
function ∇conv_filter!(dw::A, dy::A, x::A, w::A; pad = 0, stride = 1, dilation = 1,
                       algo = UInt32(0), flipkernel = 0) where A<:Array{Float32, 4}
    pad_, stride_, fallback =
        check_support(x, (size(w, 1), size(w, 2)), pad, stride, dilation)
    if fallback
        # NOTE(review): `flipkernel` is not consulted on this route — confirm
        # that `conv2d_grad_w!`'s default mode matches what callers expect.
        return conv2d_grad_w!(dw, x, w, dy, padding = pad_, stride = stride_,
                              dilation = dilation)
    else
        flip = flipkernel == 0
        # Flip into a fresh array so the caller's `w` is never modified.
        kernel = flip ? flipweight(w) : w
        # Hand over the padding/stride that `check_support` validated.
        nnp_convolution_kernel_gradient(dw, x, dy, kernel, padding = pad_,
                                        stride = stride_, algo = algo)
        # `flipkernel` is an integer flag, so it must be compared rather than
        # used as a Bool, and the flipping helper is `flipweight`.
        # `flipweight` returns a copy, which makes the in-place assignment safe.
        flip && (dw .= flipweight(dw))
        return dw
    end
end
0 commit comments