Skip to content

Commit cf02a05

Browse files
author
Avik Pal
committed
Support maxpool fallback
1 parent 8f8081a commit cf02a05

File tree

2 files changed

+48
-27
lines changed

2 files changed

+48
-27
lines changed

src/nnpack/interface.jl

Lines changed: 47 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1-
"""
    check_support(x, k, pad, stride, dilation = 1)

Expand `pad` and `stride` to one entry per element of `k` and decide whether the
request has to be served by a fallback implementation.

Returns `(pad_, stride_, fallback)`. `fallback` is `true` when `dilation` is neither
`1` nor `(1, 1)`, or when `size(x, d) - k[d] + 2 * pad_[d]` is not a multiple of
`stride_[d]` for `d = 1` or `d = 2`.
"""
function check_support(x, k, pad, stride, dilation = 1)
    pad_ = expand(Val{length(k)}, pad)
    stride_ = expand(Val{length(k)}, stride)
    # Only the first two (spatial) dimensions are inspected.
    dilated = !(dilation == 1 || dilation == (1, 1))
    uneven = !all(d -> (size(x, d) - k[d] + 2 * pad_[d]) % stride_[d] == 0, 1:2)
    return pad_, stride_, dilated || uneven
end
78

89
# NOTE: The activation functions are commented out until we decide how to handle them.
@@ -31,8 +32,12 @@ maxpool(x::A, k; pad = map(_->0,k), stride = k) where A<:AbstractArray{Float64,
3132
maxpool(Float32.(x), k, pad = pad, stride = stride)
3233

3334
"""
    maxpool(x, k; pad = map(_->0,k), stride = k)

Max-pool the `Float32` 4-D array `x` with window `k`.

The output buffer is allocated with `similar(x, pdims(size(x), k, pad_, stride_))`.
When `check_support` requests a fallback the work is done by `maxpool_cpu!`,
otherwise by `maxpool!`.
"""
function maxpool(x::A, k; pad = map(_->0,k), stride = k) where A<:AbstractArray{Float32, 4}
    pad_, stride_, fallback = check_support(x, k, pad, stride)
    y = similar(x, pdims(size(x), k, pad_, stride_))
    fallback && return maxpool_cpu!(y, x, k, pad = pad_, stride = stride_)
    return maxpool!(y, x, k, pad = pad_, stride = stride_)
end
3742

3843
maxpool!(y::A, x::A, k; pad = map(_->0,k), stride = k) where A<:AbstractArray{Float64, 4} =
@@ -45,27 +50,35 @@ conv(x::A, w::A; pad = 0, stride = 1, dilation = 1, algo = UInt32(0)) where A<:A
4550
conv(Float32.(x), Float32.(w), pad = pad, stride = stride, dilation = dilation, algo = algo)
4651

4752
"""
    conv(x, w; pad = 0, stride = 1, dilation = 1, algo = UInt32(0))

Convolve the `Float32` 4-D input `x` with the filter `w` using an all-zero bias.

The output is allocated with `cdims`, the bias has `size(y, 3)` entries, and the
work is delegated to `conv!`.

# Throws
- `ErrorException("Unsupported Operation")` when `check_support` reports that the
  combination of kernel size, `pad`, `stride` and `dilation` needs a fallback.
"""
function conv(x::A, w::A; pad = 0, stride = 1, dilation = 1, algo = UInt32(0)) where A<:AbstractArray{Float32, 4}
    # `dilation` must be forwarded: without it `check_support` only ever sees its
    # default of 1 and a dilated convolution is never flagged as unsupported.
    pad_, stride_, fallback = check_support(x, (size(w, 1), size(w, 2)), pad, stride, dilation)
    fallback && error("Unsupported Operation")
    y = similar(x, cdims(size(x), dilation_dims(w, dilation), pad_, stride_))
    b = zeros(Float32, size(y, 3))
    return conv!(y, x, w, b, pad = pad_, stride = stride_, dilation = dilation, algo = UInt32(algo))
end
5362

5463
# Float64 entry point: convert the input, filter and bias to Float32 and forward
# to the `conv` method for the converted arrays with the same keyword arguments.
function conv(x::A1, w::A1, b::A2; pad = 0, stride = 1, dilation = 1, algo = UInt32(0)) where {A1<:AbstractArray{Float64, 4}, A2<:AbstractArray{Float64, 1}}
    x32, w32, b32 = Float32.(x), Float32.(w), Float32.(b)
    return conv(x32, w32, b32; pad = pad, stride = stride, dilation = dilation, algo = algo)
end
5665

5766
"""
    conv(x, w, b; pad = 0, stride = 1, dilation = 1, algo = UInt32(0))

Convolve the `Float32` 4-D input `x` with the filter `w` and the bias vector `b`.

The output is allocated with `cdims` and the work is delegated to `conv!`.

# Throws
- `ErrorException("Unsupported Operation")` when `check_support` reports that the
  combination of kernel size, `pad`, `stride` and `dilation` needs a fallback.
"""
function conv(x::A1, w::A1, b::A2; pad = 0, stride = 1, dilation = 1, algo = UInt32(0)) where {A1<:AbstractArray{Float32, 4}, A2<:AbstractArray{Float32, 1}}
    # `dilation` must be forwarded: without it `check_support` only ever sees its
    # default of 1 and a dilated convolution is never flagged as unsupported.
    pad_, stride_, fallback = check_support(x, (size(w, 1), size(w, 2)), pad, stride, dilation)
    fallback && error("Unsupported Operation")
    y = similar(x, cdims(size(x), dilation_dims(w, dilation), pad_, stride_))
    return conv!(y, x, w, b, pad = pad_, stride = stride_, dilation = dilation, algo = UInt32(algo))
end
6174

62-
crosscor(x::A1, w::A1, b::A2; pad = 0, stride = 1, dilation = 1, algo = UInt32(0)) where {A1<:AbstractArray{Float64, 4}, A2<:AbstractArray{Float64, 1}} =
63-
crosscor(Float32.(x), Float32.(w), Float32.(b), pad = pad, stride = stride, dilation = dilation, algo = algo)
75+
# crosscor(x::A1, w::A1, b::A2; pad = 0, stride = 1, dilation = 1, algo = UInt32(0)) where {A1<:AbstractArray{Float64, 4}, A2<:AbstractArray{Float64, 1}} =
76+
# crosscor(Float32.(x), Float32.(w), Float32.(b), pad = pad, stride = stride, dilation = dilation, algo = algo)
6477

65-
function crosscor(x::A1, w::A1, b::A2; pad = 0, stride = 1, dilation = 1, algo = UInt32(0)) where {A1<:AbstractArray{Float32, 4}, A2<:AbstractArray{Float32, 1}}
66-
pad_, stride_ = check_support(x, k, pad, stride)
67-
conv!(similar(x, cdims(size(x), dilation_dims(w, dilation), pad_, stride_)), x, w, b, pad = pad_, stride = stride_, dilation = dilation, algo = UInt32(algo), flipkernel = 1)
68-
end
78+
# function crosscor(x::A1, w::A1, b::A2; pad = 0, stride = 1, dilation = 1, algo = UInt32(0)) where {A1<:AbstractArray{Float32, 4}, A2<:AbstractArray{Float32, 1}}
79+
# pad_, stride_, fallback = check_support(x, (size(w, 1), size(w, 2)), pad, stride)
80+
# conv!(similar(x, cdims(size(x), dilation_dims(w, dilation), pad_, stride_)), x, w, b, pad = pad_, stride = stride_, dilation = dilation, algo = UInt32(algo), flipkernel = 1)
81+
# end
6982

7083
# Float64 entry point for the in-place convolution: convert every array to Float32
# and forward to the Float32 `conv!` method. The previous body called `conv`, which
# has no method taking an output array or a `flipkernel` keyword.
# NOTE(review): `Float32.(y)` is a fresh array, so the result is returned but the
# caller's Float64 `y` is not written to. Confirm whether that is intended.
conv!(y::A1, x::A1, w::A1, b::A2; pad = 0, stride = 1, dilation = 1, algo = UInt32(0), flipkernel = 0) where {A1<:AbstractArray{Float64, 4}, A2<:AbstractArray{Float64, 1}} =
    conv!(Float32.(y), Float32.(x), Float32.(w), Float32.(b), pad = pad, stride = stride, dilation = dilation, algo = algo, flipkernel = flipkernel)
@@ -75,18 +88,22 @@ function conv!(y::A1, x::A1, w::A1, b::A2; pad = 0, stride = 1, dilation = 1, al
7588
nnp_convolution_output(y, x, w, b, algo = algo, padding = pad, stride = stride, threadpool = shared_threadpool[])
7689
end
7790

78-
crosscor!(y::A1, x::A1, w::A1, b::A2; pad = 0, stride = 1, dilation = 1, algo = UInt32(0)) where {A1<:AbstractArray{Float64, 4}, A2<:AbstractArray{Float64, 1}} =
79-
conv!(Float32.(y), Float32.(x), Float32.(w), Float32.(b), pad = pad, stride = stride, dilation = dilation, algo = algo, flipkernel = 1)
91+
# crosscor!(y::A1, x::A1, w::A1, b::A2; pad = 0, stride = 1, dilation = 1, algo = UInt32(0)) where {A1<:AbstractArray{Float64, 4}, A2<:AbstractArray{Float64, 1}} =
92+
# conv!(Float32.(y), Float32.(x), Float32.(w), Float32.(b), pad = pad, stride = stride, dilation = dilation, algo = algo, flipkernel = 1)
8093

81-
crosscor!(y::A1, x::A1, w::A1, b::A2; pad = 0, stride = 1, dilation = 1, algo = UInt32(0)) where {A1<:AbstractArray{Float32, 4}, A2<:AbstractArray{Float32, 1}} =
82-
conv!(y, x, w, b, pad = pad, stride = stride, dilation = dilation, algo = algo, flipkernel = 1)
94+
# crosscor!(y::A1, x::A1, w::A1, b::A2; pad = 0, stride = 1, dilation = 1, algo = UInt32(0)) where {A1<:AbstractArray{Float32, 4}, A2<:AbstractArray{Float32, 1}} =
95+
# conv!(y, x, w, b, pad = pad, stride = stride, dilation = dilation, algo = algo, flipkernel = 1)
8396

8497
# Float64 entry point: convert the output gradient, input and filter to Float32 and
# forward to the `∇conv_data` method for the converted arrays.
function ∇conv_data(dy::A, x::A, w::A; pad = 0, stride = 1, dilation = 1, algo = UInt32(0)) where A<:AbstractArray{Float64, 4}
    dy32, x32, w32 = Float32.(dy), Float32.(x), Float32.(w)
    return ∇conv_data(dy32, x32, w32; pad = pad, stride = stride, dilation = dilation, algo = algo)
end
8699

87100
"""
    ∇conv_data(dy, x, w; pad = 0, stride = 1, dilation = 1, algo = UInt32(0))

Compute the gradient of the convolution with respect to the input `x`, given the
output gradient `dy` and the filter `w` (all `Float32`, 4-D).

The result buffer is a zero-filled `Float32` array of `size(x)` that is passed to
`∇conv_data!`.

# Throws
- `ErrorException("Unsupported Operation")` when `check_support` reports that the
  combination of kernel size, `pad`, `stride` and `dilation` needs a fallback.
"""
function ∇conv_data(dy::A, x::A, w::A; pad = 0, stride = 1, dilation = 1, algo = UInt32(0)) where A<:AbstractArray{Float32, 4}
    # `dilation` must be forwarded: without it `check_support` only ever sees its
    # default of 1 and a dilated convolution is never flagged as unsupported.
    pad_, stride_, fallback = check_support(x, (size(w, 1), size(w, 2)), pad, stride, dilation)
    fallback && error("Unsupported Operation")
    dx = zeros(Float32, size(x))
    return ∇conv_data!(dx, dy, x, w; pad = pad_, stride = stride_, dilation = dilation, algo = UInt32(algo))
end
91108

92109
∇conv_data!(dx::A, dy::A, x::A, w::A; pad = 0, stride = 1, dilation = 1, algo = UInt32(0), flipkernel = 0) where A<:AbstractArray{Float64, 4} =
@@ -101,8 +118,12 @@ end
101118
∇conv_filter(Float32.(dy), Float32.(x), Float32.(w), pad = pad, stride = stride, dilation = dilation, algo = algo)
102119

103120
"""
    ∇conv_filter(dy, x, w; pad = 0, stride = 1, dilation = 1, algo = UInt32(0))

Compute the gradient of the convolution with respect to the filter `w`, given the
output gradient `dy` and the input `x` (all `Float32`, 4-D).

The result buffer is a zero-filled `Float32` array of `size(w)` that is passed to
`∇conv_filter!`.

# Throws
- `ErrorException("Unsupported Operation")` when `check_support` reports that the
  combination of kernel size, `pad`, `stride` and `dilation` needs a fallback.
"""
function ∇conv_filter(dy::A, x::A, w::A; pad = 0, stride = 1, dilation = 1, algo = UInt32(0)) where A<:AbstractArray{Float32, 4}
    # `dilation` must be forwarded: without it `check_support` only ever sees its
    # default of 1 and a dilated convolution is never flagged as unsupported.
    pad_, stride_, fallback = check_support(x, (size(w, 1), size(w, 2)), pad, stride, dilation)
    fallback && error("Unsupported Operation")
    dw = zeros(Float32, size(w))
    return ∇conv_filter!(dw, dy, x, w; pad = pad_, stride = stride_, dilation = dilation, algo = UInt32(algo))
end
107128

108129
∇conv_filter!(dw::A, dy::A, x::A, w::A; pad = 0, stride = 1, dilation = 1, algo = UInt32(0), flipkernel = 0) where A<:AbstractArray{Float64, 4} =

test/conv.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ end
147147
# correctness of gradients is cross-checked with CUDNN.jl
148148
# (it's assumed maxpooling code won't change often)
149149

150-
y = maxpool(x, (2,2))
150+
y = Float64.(maxpool(x, (2,2)))
151151
dy = reshape(rand(2,2), 2, 2, 1, 1)
152152
@test size(∇maxpool(dy, y, x, (2,2))) == size(x)
153153

0 commit comments

Comments
 (0)