@@ -50,12 +50,11 @@ conv(x::A, w::A; pad = 0, stride = 1, dilation = 1, algo = UInt32(0)) where A<:A
50
50
conv (Float32 .(x), Float32 .(w), pad = pad, stride = stride, dilation = dilation, algo = algo)
51
51
52
52
function conv (x:: A , w:: A ; pad = 0 , stride = 1 , dilation = 1 , algo = UInt32 (0 )) where A<: AbstractArray{Float32, 4}
53
- pad_, stride_, fallback = check_support (x, (size (w, 1 ), size (w, 2 )), pad, stride)
53
+ pad_, stride_, fallback = check_support (x, (size (w, 1 ), size (w, 2 )), pad, stride, dilation )
54
54
y = similar (x, cdims (size (x), dilation_dims (w, dilation), pad_, stride_))
55
55
if fallback
56
56
conv2d! (y, x, w, padding = pad, stride = stride, dilation = dilation)
57
57
else
58
- @warn " Accessed"
59
58
conv! (y, x, w, zeros (Float32, size (y, 3 )), pad = pad_, stride = stride_, dilation = dilation, algo = UInt32 (algo))
60
59
end
61
60
end
@@ -64,12 +63,11 @@ conv(x::A1, w::A1, b::A2; pad = 0, stride = 1, dilation = 1, algo = UInt32(0)) w
64
63
conv (Float32 .(x), Float32 .(w), Float32 .(b), pad = pad, stride = stride, dilation = dilation, algo = algo)
65
64
66
65
function conv (x:: A1 , w:: A1 , b:: A2 ; pad = 0 , stride = 1 , dilation = 1 , algo = UInt32 (0 )) where {A1<: AbstractArray{Float32, 4} , A2<: AbstractArray{Float32, 1} }
67
- pad_, stride_, fallback = check_support (x, (size (w, 1 ), size (w, 2 )), pad, stride)
66
+ pad_, stride_, fallback = check_support (x, (size (w, 1 ), size (w, 2 )), pad, stride, dilation )
68
67
y = similar (x, cdims (size (x), dilation_dims (w, dilation), pad_, stride_))
69
68
if fallback
70
69
conv2d! (y, x, w, padding = pad, stride = stride, dilation = dilation)
71
70
else
72
- @warn " Accessed 2.0"
73
71
conv! (y, x, w, b, pad = pad_, stride = stride_, dilation = dilation, algo = UInt32 (algo))
74
72
end
75
73
end
@@ -78,7 +76,7 @@ crosscor(x::A1, w::A1, b::A2; pad = 0, stride = 1, dilation = 1, algo = UInt32(0
78
76
crosscor (Float32 .(x), Float32 .(w), Float32 .(b), pad = pad, stride = stride, dilation = dilation, algo = algo)
79
77
80
78
function crosscor (x:: A1 , w:: A1 , b:: A2 ; pad = 0 , stride = 1 , dilation = 1 , algo = UInt32 (0 )) where {A1<: AbstractArray{Float32, 4} , A2<: AbstractArray{Float32, 1} }
81
- pad_, stride_, fallback = check_support (x, (size (w, 1 ), size (w, 2 )), pad, stride)
79
+ pad_, stride_, fallback = check_support (x, (size (w, 1 ), size (w, 2 )), pad, stride, dilation )
82
80
y = similar (x, cdims (size (x), dilation_dims (w, dilation), pad_, stride_))
83
81
if fallback
84
82
conv2d! (y, x, w, padding = pad, stride = stride, dilation = dilation, mode = 1 )
@@ -88,10 +86,12 @@ function crosscor(x::A1, w::A1, b::A2; pad = 0, stride = 1, dilation = 1, algo =
88
86
end
89
87
90
88
conv! (y:: A1 , x:: A1 , w:: A1 , b:: A2 ; pad = 0 , stride = 1 , dilation = 1 , algo = UInt32 (0 ), flipkernel = 0 ) where {A1<: AbstractArray{Float64, 4} , A2<: AbstractArray{Float64, 1} } =
91
- conv (Float32 .(y), Float32 .(x), Float32 .(w), Float32 .(b), pad = pad, stride = stride, dilation = dilation, algo = algo, flipkernel = flipkernel)
89
+ conv! (Float32 .(y), Float32 .(x), Float32 .(w), Float32 .(b), pad = pad, stride = stride, dilation = dilation, algo = algo, flipkernel = flipkernel)
92
90
93
91
function conv! (y:: A1 , x:: A1 , w:: A1 , b:: A2 ; pad = 0 , stride = 1 , dilation = 1 , algo = UInt32 (0 ), flipkernel = 0 ) where {A1<: AbstractArray{Float32, 4} , A2<: AbstractArray{Float32, 1} }
94
- flipkernel == 0 && (w = reverse (reverse (w, dims= 1 ), dims= 2 ))
92
+ if flipkernel == 0
93
+ w = reverse (reverse (w, dims= 1 ), dims= 2 )
94
+ end
95
95
nnp_convolution_output (y, x, w, b, algo = algo, padding = pad, stride = stride, threadpool = shared_threadpool[])
96
96
end
97
97
@@ -105,9 +105,9 @@ crosscor!(y::A1, x::A1, w::A1, b::A2; pad = 0, stride = 1, dilation = 1, algo =
105
105
∇conv_data (Float32 .(dy), Float32 .(x), Float32 .(w), pad = pad, stride = stride, dilation = dilation, algo = algo)
106
106
107
107
function ∇conv_data (dy:: A , x:: A , w:: A ; pad = 0 , stride = 1 , dilation = 1 , algo = UInt32 (0 )) where A<: AbstractArray{Float32, 4}
108
- pad_, stride_, fallback = check_support (x, (size (w, 1 ), size (w, 2 )), pad, stride)
108
+ pad_, stride_, fallback = check_support (x, (size (w, 1 ), size (w, 2 )), pad, stride, dilation )
109
109
if fallback
110
- error ( " Unsupported Operation " )
110
+ conv2d_grad_x! ( zeros (Float32, size (x)), x, w, dy, padding = pad_, stride = stride_, dilation = dilation )
111
111
else
112
112
∇conv_data! (zeros (Float32, size (x)), dy, x, w; pad = pad_, stride = stride_, dilation = dilation, algo = UInt32 (algo))
113
113
end
125
125
∇conv_filter (Float32 .(dy), Float32 .(x), Float32 .(w), pad = pad, stride = stride, dilation = dilation, algo = algo)
126
126
127
127
function ∇conv_filter (dy:: A , x:: A , w:: A ; pad = 0 , stride = 1 , dilation = 1 , algo = UInt32 (0 )) where A<: AbstractArray{Float32, 4}
128
- pad_, stride_, fallback = check_support (x, (size (w, 1 ), size (w, 2 )), pad, stride)
128
+ pad_, stride_, fallback = check_support (x, (size (w, 1 ), size (w, 2 )), pad, stride, dilation )
129
129
if fallback
130
- error ( " Unsupported Operation " )
130
+ conv2d_grad_w! ( zeros (Float32, size (w)), x, w, dy, padding = pad_, stride = stride_, dilation = dilation )
131
131
else
132
132
∇conv_filter! (zeros (Float32, size (w)), dy, x, w; pad = pad_, stride = stride_, dilation = dilation, algo = UInt32 (algo))
133
133
end
0 commit comments