@@ -51,12 +51,12 @@ conv(x::A, w::A; pad = 0, stride = 1, dilation = 1, algo = UInt32(0)) where A<:A
51
51
52
52
function conv (x:: A , w:: A ; pad = 0 , stride = 1 , dilation = 1 , algo = UInt32 (0 )) where A<: AbstractArray{Float32, 4}
53
53
pad_, stride_, fallback = check_support (x, (size (w, 1 ), size (w, 2 )), pad, stride)
54
+ y = similar (x, cdims (size (x), dilation_dims (w, dilation), pad_, stride_))
54
55
if fallback
55
- error ( " Unsupported Operation " )
56
+ conv2d! (y, x, w, padding = pad, stride = stride, dilation = dilation )
56
57
else
57
- y = similar (x, cdims (size (x), dilation_dims (w, dilation), pad_, stride_))
58
- b = zeros (Float32, size (y, 3 ))
59
- conv! (y, x, w, b, pad = pad_, stride = stride_, dilation = dilation, algo = UInt32 (algo))
58
+ @warn " Accessed"
59
+ conv! (y, x, w, zeros (Float32, size (y, 3 )), pad = pad_, stride = stride_, dilation = dilation, algo = UInt32 (algo))
60
60
end
61
61
end
62
62
@@ -65,20 +65,27 @@ conv(x::A1, w::A1, b::A2; pad = 0, stride = 1, dilation = 1, algo = UInt32(0)) w
65
65
66
66
function conv (x:: A1 , w:: A1 , b:: A2 ; pad = 0 , stride = 1 , dilation = 1 , algo = UInt32 (0 )) where {A1<: AbstractArray{Float32, 4} , A2<: AbstractArray{Float32, 1} }
67
67
pad_, stride_, fallback = check_support (x, (size (w, 1 ), size (w, 2 )), pad, stride)
68
+ y = similar (x, cdims (size (x), dilation_dims (w, dilation), pad_, stride_))
68
69
if fallback
69
- error (" Unsupported Operation" )
70
- else
71
- conv! (similar (x, cdims (size (x), dilation_dims (w, dilation), pad_, stride_)), x, w, b, pad = pad_, stride = stride_, dilation = dilation, algo = UInt32 (algo))
70
+ conv2d! (y, x, w, padding = pad, stride = stride, dilation = dilation)
71
+ else
72
+ @warn " Accessed 2.0"
73
+ conv! (y, x, w, b, pad = pad_, stride = stride_, dilation = dilation, algo = UInt32 (algo))
72
74
end
73
75
end
74
76
75
- # crosscor(x::A1, w::A1, b::A2; pad = 0, stride = 1, dilation = 1, algo = UInt32(0)) where {A1<:AbstractArray{Float64, 4}, A2<:AbstractArray{Float64, 1}} =
76
- # crosscor(Float32.(x), Float32.(w), Float32.(b), pad = pad, stride = stride, dilation = dilation, algo = algo)
77
+ crosscor (x:: A1 , w:: A1 , b:: A2 ; pad = 0 , stride = 1 , dilation = 1 , algo = UInt32 (0 )) where {A1<: AbstractArray{Float64, 4} , A2<: AbstractArray{Float64, 1} } =
78
+ crosscor (Float32 .(x), Float32 .(w), Float32 .(b), pad = pad, stride = stride, dilation = dilation, algo = algo)
77
79
78
- # function crosscor(x::A1, w::A1, b::A2; pad = 0, stride = 1, dilation = 1, algo = UInt32(0)) where {A1<:AbstractArray{Float32, 4}, A2<:AbstractArray{Float32, 1}}
79
- # pad_, stride_, fallback = check_support(x, (size(w, 1), size(w, 2)), pad, stride)
80
- # conv!(similar(x, cdims(size(x), dilation_dims(w, dilation), pad_, stride_)), x, w, b, pad = pad_, stride = stride_, dilation = dilation, algo = UInt32(algo), flipkernel = 1)
81
- # end
80
+ function crosscor (x:: A1 , w:: A1 , b:: A2 ; pad = 0 , stride = 1 , dilation = 1 , algo = UInt32 (0 )) where {A1<: AbstractArray{Float32, 4} , A2<: AbstractArray{Float32, 1} }
81
+ pad_, stride_, fallback = check_support (x, (size (w, 1 ), size (w, 2 )), pad, stride)
82
+ y = similar (x, cdims (size (x), dilation_dims (w, dilation), pad_, stride_))
83
+ if fallback
84
+ conv2d! (y, x, w, padding = pad, stride = stride, dilation = dilation, mode = 1 )
85
+ else
86
+ conv! (y, x, w, b, pad = pad_, stride = stride_, dilation = dilation, algo = UInt32 (algo), flipkernel = 1 )
87
+ end
88
+ end
82
89
83
90
conv! (y:: A1 , x:: A1 , w:: A1 , b:: A2 ; pad = 0 , stride = 1 , dilation = 1 , algo = UInt32 (0 ), flipkernel = 0 ) where {A1<: AbstractArray{Float64, 4} , A2<: AbstractArray{Float64, 1} } =
84
91
conv (Float32 .(y), Float32 .(x), Float32 .(w), Float32 .(b), pad = pad, stride = stride, dilation = dilation, algo = algo, flipkernel = flipkernel)
@@ -88,11 +95,11 @@ function conv!(y::A1, x::A1, w::A1, b::A2; pad = 0, stride = 1, dilation = 1, al
88
95
nnp_convolution_output (y, x, w, b, algo = algo, padding = pad, stride = stride, threadpool = shared_threadpool[])
89
96
end
90
97
91
- # crosscor!(y::A1, x::A1, w::A1, b::A2; pad = 0, stride = 1, dilation = 1, algo = UInt32(0)) where {A1<:AbstractArray{Float64, 4}, A2<:AbstractArray{Float64, 1}} =
92
- # conv!(Float32.(y), Float32.(x), Float32.(w), Float32.(b), pad = pad, stride = stride, dilation = dilation, algo = algo, flipkernel = 1)
98
+ crosscor! (y:: A1 , x:: A1 , w:: A1 , b:: A2 ; pad = 0 , stride = 1 , dilation = 1 , algo = UInt32 (0 )) where {A1<: AbstractArray{Float64, 4} , A2<: AbstractArray{Float64, 1} } =
99
+ conv! (Float32 .(y), Float32 .(x), Float32 .(w), Float32 .(b), pad = pad, stride = stride, dilation = dilation, algo = algo, flipkernel = 1 )
93
100
94
- # crosscor!(y::A1, x::A1, w::A1, b::A2; pad = 0, stride = 1, dilation = 1, algo = UInt32(0)) where {A1<:AbstractArray{Float32, 4}, A2<:AbstractArray{Float32, 1}} =
95
- # conv!(y, x, w, b, pad = pad, stride = stride, dilation = dilation, algo = algo, flipkernel = 1)
101
+ crosscor! (y:: A1 , x:: A1 , w:: A1 , b:: A2 ; pad = 0 , stride = 1 , dilation = 1 , algo = UInt32 (0 )) where {A1<: AbstractArray{Float32, 4} , A2<: AbstractArray{Float32, 1} } =
102
+ conv! (y, x, w, b, pad = pad, stride = stride, dilation = dilation, algo = algo, flipkernel = 1 )
96
103
97
104
∇conv_data (dy:: A , x:: A , w:: A ; pad = 0 , stride = 1 , dilation = 1 , algo = UInt32 (0 )) where A<: AbstractArray{Float64, 4} =
98
105
∇conv_data (Float32 .(dy), Float32 .(x), Float32 .(w), pad = pad, stride = stride, dilation = dilation, algo = algo)
0 commit comments