Skip to content

Commit 9df3631

Browse files
committed
Integrate NNPACK convolutions into tests a bit more
1 parent 52e0310 commit 9df3631

File tree

4 files changed

+19
-8
lines changed

4 files changed

+19
-8
lines changed

src/conv.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ end
159159
if is_nnpack_available()
160160
function conv(x::Array{xT, 4}, w::Array{wT, 4},
161161
cdims::DenseConvDims{2, K, C_in, C_out, (1, 1), P, (1, 1), F};
162-
kwargs...) where {xT, wT, K, C_in, C_out, S, P, F}
162+
kwargs...) where {xT, wT, K, C_in, C_out, P, F}
163163
return conv_nnpack(x, w, cdims; kwargs...)
164164
end
165165
end

src/nnpack/interface.jl

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,11 +52,19 @@ end
5252

5353

5454
"""
55-
check_supported_operation(x::Array, pdims::PoolDims)
55+
nnpack_supported_operation(cdims::ConvDims)
56+
nnpack_supported_operation(pdims::PoolDims)
5657
57-
Returns `true` if nnpack supports the pooling operation for the given input.
58+
Returns `true` if nnpack supports the convolution/pooling operation for the given parameters.
5859
"""
59-
function check_supported_operation(x::Array{T, 4}, pdims::PoolDims{2, K, S, P, (1, 1)}) where {T, K, S, P}
60-
val = size(x)[1:2] .+ (P[1] + P[2], P[3] + P[4]) .- K
60+
function nnpack_supported_operation(pdims::PoolDims{2, K, S, P, (1, 1)}) where {K, S, P}
61+
val = input_size(pdims)[1:2] .+ (P[1] + P[2], P[3] + P[4]) .- K
6162
return val .% S == (0, 0) ? true : false
6263
end
64+
65+
function nnpack_supported_operation(cdims::ConvDims{2, K, (1, 1), P, (1, 1)}) where {K, S, P}
66+
return true
67+
end
68+
69+
# Return false for everything else
70+
nnpack_supported_operation(dims) = false

src/pooling.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ end
132132
# Use NNPACK if it is available and operation is supported
133133
if is_nnpack_available()
134134
function maxpool(x::Array{T, 4}, pdims::PoolDims{2, K, S, P, (1, 1)}; kwargs...) where {T, K, S, P}
135-
func = check_supported_operation(x, pdims) ? maxpool_nnpack : maxpool_direct
135+
func = nnpack_supported_operation(pdims) ? maxpool_nnpack : maxpool_direct
136136
return func(x, pdims; kwargs...)
137137
end
138138
end

test/conv.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -274,10 +274,13 @@ conv_answer_dict = Dict(
274274
# A "drop channels and batch dimension" helper
275275
ddims(x) = dropdims(x, dims=(rank+1, rank+2))
276276

277-
for conv in (NNlib.conv, NNlib.conv_im2col, NNlib.conv_direct)
277+
for conv in (NNlib.conv, NNlib.conv_im2col, NNlib.conv_direct, NNlib.conv_nnpack)
278+
if conv == NNlib.conv_nnpack && !NNlib.nnpack_supported_operation(DenseConvDims(x, w))
279+
continue
280+
end
278281
@testset "$(conv)" begin
279-
# First, your basic convolution with no parameters
280282
cdims = DenseConvDims(x, w)
283+
# First, your basic convolution with no parameters
281284
@test isapprox(ddims(conv(x, w, cdims)), y_plain, rtol = 1.0e-7)
282285

283286
# Next, test convolution on views and alternate datatypes:

0 commit comments

Comments
 (0)