1
- function check_support (x, k, pad, stride, dilation = 0 )
2
- dilation == 1 || dilation == (1 , 1 ) || error (" NNPACK does not support dilation > 1" )
1
+ function check_support (x, k, pad, stride, dilation = 1 )
2
+ fallback = false
3
+ dilation == 1 || dilation == (1 , 1 ) || (fallback = true )
3
4
pad_, stride_ = expand (Val{length (k)}, pad), expand (Val{length (k)}, stride)
4
- ((size (x, 1 ) - k[1 ] + 2 * pad_[1 ]) % stride_[1 ] == 0 && (size (x, 2 ) - k[2 ] + 2 * pad_[2 ]) % stride_[2 ] == 0 ) || error ( " Choose the stride, pad and kernel size properly " )
5
- return pad_, stride_
5
+ ((size (x, 1 ) - k[1 ] + 2 * pad_[1 ]) % stride_[1 ] == 0 && (size (x, 2 ) - k[2 ] + 2 * pad_[2 ]) % stride_[2 ] == 0 ) || (fallback = true )
6
+ return pad_, stride_, fallback
6
7
end
7
8
8
9
# NOTE: Commenting out the activation functions until sure what to do
@@ -31,8 +32,12 @@ maxpool(x::A, k; pad = map(_->0,k), stride = k) where A<:AbstractArray{Float64,
31
32
maxpool (Float32 .(x), k, pad = pad, stride = stride)
32
33
33
34
function maxpool (x:: A , k; pad = map (_-> 0 ,k), stride = k) where A<: AbstractArray{Float32, 4}
34
- pad_, stride_ = check_support (x, k, pad, stride)
35
- maxpool! (similar (x, pdims (size (x), k, pad_, stride_)), x, k, pad = pad_, stride = stride_)
35
+ pad_, stride_, fallback = check_support (x, k, pad, stride)
36
+ if fallback
37
+ maxpool_cpu! (similar (x, pdims (size (x), k, pad_, stride_)), x, k, pad = pad_, stride = stride_)
38
+ else
39
+ maxpool! (similar (x, pdims (size (x), k, pad_, stride_)), x, k, pad = pad_, stride = stride_)
40
+ end
36
41
end
37
42
38
43
maxpool! (y:: A , x:: A , k; pad = map (_-> 0 ,k), stride = k) where A<: AbstractArray{Float64, 4} =
@@ -45,27 +50,35 @@ conv(x::A, w::A; pad = 0, stride = 1, dilation = 1, algo = UInt32(0)) where A<:A
45
50
conv (Float32 .(x), Float32 .(w), pad = pad, stride = stride, dilation = dilation, algo = algo)
46
51
47
52
function conv (x:: A , w:: A ; pad = 0 , stride = 1 , dilation = 1 , algo = UInt32 (0 )) where A<: AbstractArray{Float32, 4}
48
- pad_, stride_ = check_support (x, k, pad, stride)
49
- y = similar (x, cdims (size (x), dilation_dims (w, dilation), pad_, stride_))
50
- b = zeros (Float32, size (y, 3 ))
51
- conv! (y, x, w, b, pad = pad_, stride = stride_, dilation = dilation, algo = UInt32 (algo))
53
+ pad_, stride_, fallback = check_support (x, (size (w, 1 ), size (w, 2 )), pad, stride)
54
+ if fallback
55
+ error (" Unsupported Operation" )
56
+ else
57
+ y = similar (x, cdims (size (x), dilation_dims (w, dilation), pad_, stride_))
58
+ b = zeros (Float32, size (y, 3 ))
59
+ conv! (y, x, w, b, pad = pad_, stride = stride_, dilation = dilation, algo = UInt32 (algo))
60
+ end
52
61
end
53
62
54
63
conv (x:: A1 , w:: A1 , b:: A2 ; pad = 0 , stride = 1 , dilation = 1 , algo = UInt32 (0 )) where {A1<: AbstractArray{Float64, 4} , A2<: AbstractArray{Float64, 1} } =
55
64
conv (Float32 .(x), Float32 .(w), Float32 .(b), pad = pad, stride = stride, dilation = dilation, algo = algo)
56
65
57
66
function conv (x:: A1 , w:: A1 , b:: A2 ; pad = 0 , stride = 1 , dilation = 1 , algo = UInt32 (0 )) where {A1<: AbstractArray{Float32, 4} , A2<: AbstractArray{Float32, 1} }
58
- pad_, stride_ = check_support (x, k, pad, stride)
59
- conv! (similar (x, cdims (size (x), dilation_dims (w, dilation), pad_, stride_)), x, w, b, pad = pad_, stride = stride_, dilation = dilation, algo = UInt32 (algo))
67
+ pad_, stride_, fallback = check_support (x, (size (w, 1 ), size (w, 2 )), pad, stride)
68
+ if fallback
69
+ error (" Unsupported Operation" )
70
+ else
71
+ conv! (similar (x, cdims (size (x), dilation_dims (w, dilation), pad_, stride_)), x, w, b, pad = pad_, stride = stride_, dilation = dilation, algo = UInt32 (algo))
72
+ end
60
73
end
61
74
62
- crosscor (x:: A1 , w:: A1 , b:: A2 ; pad = 0 , stride = 1 , dilation = 1 , algo = UInt32 (0 )) where {A1<: AbstractArray{Float64, 4} , A2<: AbstractArray{Float64, 1} } =
63
- crosscor (Float32 .(x), Float32 .(w), Float32 .(b), pad = pad, stride = stride, dilation = dilation, algo = algo)
75
+ # crosscor(x::A1, w::A1, b::A2; pad = 0, stride = 1, dilation = 1, algo = UInt32(0)) where {A1<:AbstractArray{Float64, 4}, A2<:AbstractArray{Float64, 1}} =
76
+ # crosscor(Float32.(x), Float32.(w), Float32.(b), pad = pad, stride = stride, dilation = dilation, algo = algo)
64
77
65
- function crosscor (x:: A1 , w:: A1 , b:: A2 ; pad = 0 , stride = 1 , dilation = 1 , algo = UInt32 (0 )) where {A1<: AbstractArray{Float32, 4} , A2<: AbstractArray{Float32, 1} }
66
- pad_, stride_ = check_support (x, k , pad, stride)
67
- conv! (similar (x, cdims (size (x), dilation_dims (w, dilation), pad_, stride_)), x, w, b, pad = pad_, stride = stride_, dilation = dilation, algo = UInt32 (algo), flipkernel = 1 )
68
- end
78
+ # function crosscor(x::A1, w::A1, b::A2; pad = 0, stride = 1, dilation = 1, algo = UInt32(0)) where {A1<:AbstractArray{Float32, 4}, A2<:AbstractArray{Float32, 1}}
79
+ # pad_, stride_, fallback = check_support(x, (size(w, 1), size(w, 2)) , pad, stride)
80
+ # conv!(similar(x, cdims(size(x), dilation_dims(w, dilation), pad_, stride_)), x, w, b, pad = pad_, stride = stride_, dilation = dilation, algo = UInt32(algo), flipkernel = 1)
81
+ # end
69
82
70
83
conv! (y:: A1 , x:: A1 , w:: A1 , b:: A2 ; pad = 0 , stride = 1 , dilation = 1 , algo = UInt32 (0 ), flipkernel = 0 ) where {A1<: AbstractArray{Float64, 4} , A2<: AbstractArray{Float64, 1} } =
71
84
conv (Float32 .(y), Float32 .(x), Float32 .(w), Float32 .(b), pad = pad, stride = stride, dilation = dilation, algo = algo, flipkernel = flipkernel)
@@ -75,18 +88,22 @@ function conv!(y::A1, x::A1, w::A1, b::A2; pad = 0, stride = 1, dilation = 1, al
75
88
nnp_convolution_output (y, x, w, b, algo = algo, padding = pad, stride = stride, threadpool = shared_threadpool[])
76
89
end
77
90
78
- crosscor! (y:: A1 , x:: A1 , w:: A1 , b:: A2 ; pad = 0 , stride = 1 , dilation = 1 , algo = UInt32 (0 )) where {A1<: AbstractArray{Float64, 4} , A2<: AbstractArray{Float64, 1} } =
79
- conv! (Float32 .(y), Float32 .(x), Float32 .(w), Float32 .(b), pad = pad, stride = stride, dilation = dilation, algo = algo, flipkernel = 1 )
91
+ # crosscor!(y::A1, x::A1, w::A1, b::A2; pad = 0, stride = 1, dilation = 1, algo = UInt32(0)) where {A1<:AbstractArray{Float64, 4}, A2<:AbstractArray{Float64, 1}} =
92
+ # conv!(Float32.(y), Float32.(x), Float32.(w), Float32.(b), pad = pad, stride = stride, dilation = dilation, algo = algo, flipkernel = 1)
80
93
81
- crosscor! (y:: A1 , x:: A1 , w:: A1 , b:: A2 ; pad = 0 , stride = 1 , dilation = 1 , algo = UInt32 (0 )) where {A1<: AbstractArray{Float32, 4} , A2<: AbstractArray{Float32, 1} } =
82
- conv! (y, x, w, b, pad = pad, stride = stride, dilation = dilation, algo = algo, flipkernel = 1 )
94
+ # crosscor!(y::A1, x::A1, w::A1, b::A2; pad = 0, stride = 1, dilation = 1, algo = UInt32(0)) where {A1<:AbstractArray{Float32, 4}, A2<:AbstractArray{Float32, 1}} =
95
+ # conv!(y, x, w, b, pad = pad, stride = stride, dilation = dilation, algo = algo, flipkernel = 1)
83
96
84
97
∇conv_data (dy:: A , x:: A , w:: A ; pad = 0 , stride = 1 , dilation = 1 , algo = UInt32 (0 )) where A<: AbstractArray{Float64, 4} =
85
98
∇conv_data (Float32 .(dy), Float32 .(x), Float32 .(w), pad = pad, stride = stride, dilation = dilation, algo = algo)
86
99
87
100
function ∇conv_data (dy:: A , x:: A , w:: A ; pad = 0 , stride = 1 , dilation = 1 , algo = UInt32 (0 )) where A<: AbstractArray{Float32, 4}
88
- pad_, stride_ = check_support (x, k, pad, stride)
89
- ∇conv_data! (zeros (Float32, size (x)), dy, x, w; pad = pad_, stride = stride_, dilation = dilation, algo = UInt32 (algo))
101
+ pad_, stride_, fallback = check_support (x, (size (w, 1 ), size (w, 2 )), pad, stride)
102
+ if fallback
103
+ error (" Unsupported Operation" )
104
+ else
105
+ ∇conv_data! (zeros (Float32, size (x)), dy, x, w; pad = pad_, stride = stride_, dilation = dilation, algo = UInt32 (algo))
106
+ end
90
107
end
91
108
92
109
∇conv_data! (dx:: A , dy:: A , x:: A , w:: A ; pad = 0 , stride = 1 , dilation = 1 , algo = UInt32 (0 ), flipkernel = 0 ) where A<: AbstractArray{Float64, 4} =
101
118
∇conv_filter (Float32 .(dy), Float32 .(x), Float32 .(w), pad = pad, stride = stride, dilation = dilation, algo = algo)
102
119
103
120
function ∇conv_filter (dy:: A , x:: A , w:: A ; pad = 0 , stride = 1 , dilation = 1 , algo = UInt32 (0 )) where A<: AbstractArray{Float32, 4}
104
- pad_, stride_ = check_support (x, k, pad, stride)
105
- ∇conv_filter! (zeros (Float32, size (w)), dy, x, w; pad = pad_, stride = stride_, dilation = dilation, algo = UInt32 (algo))
121
+ pad_, stride_, fallback = check_support (x, (size (w, 1 ), size (w, 2 )), pad, stride)
122
+ if fallback
123
+ error (" Unsupported Operation" )
124
+ else
125
+ ∇conv_filter! (zeros (Float32, size (w)), dy, x, w; pad = pad_, stride = stride_, dilation = dilation, algo = UInt32 (algo))
126
+ end
106
127
end
107
128
108
129
∇conv_filter! (dw:: A , dy:: A , x:: A , w:: A ; pad = 0 , stride = 1 , dilation = 1 , algo = UInt32 (0 ), flipkernel = 0 ) where A<: AbstractArray{Float64, 4} =
0 commit comments