Skip to content

Commit fba6d4a

Browse files
author
Avik Pal
committed
conv fallbacks
1 parent cf02a05 commit fba6d4a

File tree

1 file changed

+24
-17
lines changed

1 file changed

+24
-17
lines changed

src/nnpack/interface.jl

Lines changed: 24 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -51,12 +51,12 @@ conv(x::A, w::A; pad = 0, stride = 1, dilation = 1, algo = UInt32(0)) where A<:A
5151

5252
function conv(x::A, w::A; pad = 0, stride = 1, dilation = 1, algo = UInt32(0)) where A<:AbstractArray{Float32, 4}
5353
pad_, stride_, fallback = check_support(x, (size(w, 1), size(w, 2)), pad, stride)
54+
y = similar(x, cdims(size(x), dilation_dims(w, dilation), pad_, stride_))
5455
if fallback
55-
error("Unsupported Operation")
56+
conv2d!(y, x, w, padding = pad, stride = stride, dilation = dilation)
5657
else
57-
y = similar(x, cdims(size(x), dilation_dims(w, dilation), pad_, stride_))
58-
b = zeros(Float32, size(y, 3))
59-
conv!(y, x, w, b, pad = pad_, stride = stride_, dilation = dilation, algo = UInt32(algo))
58+
@warn "Accessed"
59+
conv!(y, x, w, zeros(Float32, size(y, 3)), pad = pad_, stride = stride_, dilation = dilation, algo = UInt32(algo))
6060
end
6161
end
6262

@@ -65,20 +65,27 @@ conv(x::A1, w::A1, b::A2; pad = 0, stride = 1, dilation = 1, algo = UInt32(0)) w
6565

6666
function conv(x::A1, w::A1, b::A2; pad = 0, stride = 1, dilation = 1, algo = UInt32(0)) where {A1<:AbstractArray{Float32, 4}, A2<:AbstractArray{Float32, 1}}
6767
pad_, stride_, fallback = check_support(x, (size(w, 1), size(w, 2)), pad, stride)
68+
y = similar(x, cdims(size(x), dilation_dims(w, dilation), pad_, stride_))
6869
if fallback
69-
error("Unsupported Operation")
70-
else
71-
conv!(similar(x, cdims(size(x), dilation_dims(w, dilation), pad_, stride_)), x, w, b, pad = pad_, stride = stride_, dilation = dilation, algo = UInt32(algo))
70+
conv2d!(y, x, w, padding = pad, stride = stride, dilation = dilation)
71+
else
72+
@warn "Accessed 2.0"
73+
conv!(y, x, w, b, pad = pad_, stride = stride_, dilation = dilation, algo = UInt32(algo))
7274
end
7375
end
7476

75-
# crosscor(x::A1, w::A1, b::A2; pad = 0, stride = 1, dilation = 1, algo = UInt32(0)) where {A1<:AbstractArray{Float64, 4}, A2<:AbstractArray{Float64, 1}} =
76-
# crosscor(Float32.(x), Float32.(w), Float32.(b), pad = pad, stride = stride, dilation = dilation, algo = algo)
77+
crosscor(x::A1, w::A1, b::A2; pad = 0, stride = 1, dilation = 1, algo = UInt32(0)) where {A1<:AbstractArray{Float64, 4}, A2<:AbstractArray{Float64, 1}} =
78+
crosscor(Float32.(x), Float32.(w), Float32.(b), pad = pad, stride = stride, dilation = dilation, algo = algo)
7779

78-
# function crosscor(x::A1, w::A1, b::A2; pad = 0, stride = 1, dilation = 1, algo = UInt32(0)) where {A1<:AbstractArray{Float32, 4}, A2<:AbstractArray{Float32, 1}}
79-
# pad_, stride_, fallback = check_support(x, (size(w, 1), size(w, 2)), pad, stride)
80-
# conv!(similar(x, cdims(size(x), dilation_dims(w, dilation), pad_, stride_)), x, w, b, pad = pad_, stride = stride_, dilation = dilation, algo = UInt32(algo), flipkernel = 1)
81-
# end
80+
function crosscor(x::A1, w::A1, b::A2; pad = 0, stride = 1, dilation = 1, algo = UInt32(0)) where {A1<:AbstractArray{Float32, 4}, A2<:AbstractArray{Float32, 1}}
81+
pad_, stride_, fallback = check_support(x, (size(w, 1), size(w, 2)), pad, stride)
82+
y = similar(x, cdims(size(x), dilation_dims(w, dilation), pad_, stride_))
83+
if fallback
84+
conv2d!(y, x, w, padding = pad, stride = stride, dilation = dilation, mode = 1)
85+
else
86+
conv!(y, x, w, b, pad = pad_, stride = stride_, dilation = dilation, algo = UInt32(algo), flipkernel = 1)
87+
end
88+
end
8289

8390
conv!(y::A1, x::A1, w::A1, b::A2; pad = 0, stride = 1, dilation = 1, algo = UInt32(0), flipkernel = 0) where {A1<:AbstractArray{Float64, 4}, A2<:AbstractArray{Float64, 1}} =
8491
conv(Float32.(y), Float32.(x), Float32.(w), Float32.(b), pad = pad, stride = stride, dilation = dilation, algo = algo, flipkernel = flipkernel)
@@ -88,11 +95,11 @@ function conv!(y::A1, x::A1, w::A1, b::A2; pad = 0, stride = 1, dilation = 1, al
8895
nnp_convolution_output(y, x, w, b, algo = algo, padding = pad, stride = stride, threadpool = shared_threadpool[])
8996
end
9097

91-
# crosscor!(y::A1, x::A1, w::A1, b::A2; pad = 0, stride = 1, dilation = 1, algo = UInt32(0)) where {A1<:AbstractArray{Float64, 4}, A2<:AbstractArray{Float64, 1}} =
92-
# conv!(Float32.(y), Float32.(x), Float32.(w), Float32.(b), pad = pad, stride = stride, dilation = dilation, algo = algo, flipkernel = 1)
98+
crosscor!(y::A1, x::A1, w::A1, b::A2; pad = 0, stride = 1, dilation = 1, algo = UInt32(0)) where {A1<:AbstractArray{Float64, 4}, A2<:AbstractArray{Float64, 1}} =
99+
conv!(Float32.(y), Float32.(x), Float32.(w), Float32.(b), pad = pad, stride = stride, dilation = dilation, algo = algo, flipkernel = 1)
93100

94-
# crosscor!(y::A1, x::A1, w::A1, b::A2; pad = 0, stride = 1, dilation = 1, algo = UInt32(0)) where {A1<:AbstractArray{Float32, 4}, A2<:AbstractArray{Float32, 1}} =
95-
# conv!(y, x, w, b, pad = pad, stride = stride, dilation = dilation, algo = algo, flipkernel = 1)
101+
crosscor!(y::A1, x::A1, w::A1, b::A2; pad = 0, stride = 1, dilation = 1, algo = UInt32(0)) where {A1<:AbstractArray{Float32, 4}, A2<:AbstractArray{Float32, 1}} =
102+
conv!(y, x, w, b, pad = pad, stride = stride, dilation = dilation, algo = algo, flipkernel = 1)
96103

97104
∇conv_data(dy::A, x::A, w::A; pad = 0, stride = 1, dilation = 1, algo = UInt32(0)) where A<:AbstractArray{Float64, 4} =
98105
∇conv_data(Float32.(dy), Float32.(x), Float32.(w), pad = pad, stride = stride, dilation = dilation, algo = algo)

0 commit comments

Comments
 (0)