Skip to content

Commit 5449ebc

Browse files
author
Avik Pal
committed
Add tests for inference and remove timeroutputs
1 parent 5334c11 commit 5449ebc

File tree

8 files changed

+93
-87
lines changed

8 files changed

+93
-87
lines changed

src/impl/conv_direct.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ wrapper methods are available.
4444
"""
4545
conv_direct!
4646

47-
@timeit_debug to function conv_direct!(y::AbstractArray{yT,5}, x::AbstractArray{xT,5},
47+
function conv_direct!(y::AbstractArray{yT,5}, x::AbstractArray{xT,5},
4848
w::AbstractArray{wT,5}, cdims::DenseConvDims;
4949
alpha::yT = yT(1), beta = false) where {yT, xT, wT}
5050
check_dims(size(x), size(w), size(y), cdims)
@@ -114,7 +114,7 @@ Calculate the gradient imposed upon `x` in the convolution `y = x * w`.
114114
"""
115115
∇conv_data_direct!
116116

117-
@timeit_debug to function ∇conv_data_direct!(dx::AbstractArray{xT,5}, dy::AbstractArray{yT,5},
117+
function ∇conv_data_direct!(dx::AbstractArray{xT,5}, dy::AbstractArray{yT,5},
118118
w::AbstractArray{wT,5}, cdims::DenseConvDims;
119119
alpha::xT=xT(1), beta=false) where {xT, yT, wT}
120120
w = transpose_swapbatch(w[end:-1:1, end:-1:1, end:-1:1, :, :])
@@ -133,7 +133,7 @@ Calculate the gradient imposed upon `w` in the convolution `y = x * w`.
133133
"""
134134
∇conv_filter_direct!
135135

136-
@timeit_debug to function ∇conv_filter_direct!(dw::AbstractArray{wT,5}, x::AbstractArray{xT,5},
136+
function ∇conv_filter_direct!(dw::AbstractArray{wT,5}, x::AbstractArray{xT,5},
137137
dy::AbstractArray{yT,5}, cdims::DenseConvDims;
138138
alpha::wT=wT(1), beta=false) where {xT, yT, wT}
139139
x = transpose_swapbatch(x[end:-1:1, end:-1:1, end:-1:1, :, :])

src/impl/conv_im2col.jl

Lines changed: 62 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ by setting `alpha` to a nonunitary value, various gain factors can be applied.
2222
Note for the particularly performance-minded, you can provide a pre-allocated `col`,
2323
which should eliminate any need for large allocations within this method.
2424
"""
25-
@timeit_debug to function conv_im2col!(
25+
function conv_im2col!(
2626
y::AbstractArray{T,5}, x::AbstractArray{T,5},
2727
w::AbstractArray{T,5}, cdims::DenseConvDims;
2828
col::AbstractArray{T,2}=similar(x, im2col_dims(cdims)),
@@ -49,12 +49,12 @@ which should eliminate any need for large allocations within this method.
4949
@inbounds for batch_idx in 1:size(x,5)
5050
# We invoke `@timeit_debug` on the outside of `im2col!()` because inference
5151
# doesn't like us putting it on the inside.
52-
@timeit_debug to "im2col!" im2col!(col, view(x, :, :, :, :, batch_idx), cdims)
52+
im2col!(col, view(x, :, :, :, :, batch_idx), cdims)
5353
GC.@preserve col, w, y, begin
5454
col_ptr = pointer(col)
5555
w_ptr = pointer(w)
5656
y_ptr = pointer(y, (batch_idx - 1)*M*N + 1)
57-
@timeit_debug to "gemm!" gemm!(Val(false), Val(false), M, N, K, alpha, col_ptr, w_ptr, beta, y_ptr)
57+
gemm!(Val(false), Val(false), M, N, K, alpha, col_ptr, w_ptr, beta, y_ptr)
5858
end
5959
end
6060
return y
@@ -66,7 +66,7 @@ end
6666
Conv backward pass onto the weights using im2col and GEMM; stores the result in `dw`.
6767
See the documentation for `conv_im2col!()` for explanation of optional parameters.
6868
"""
69-
@timeit_debug to function ∇conv_filter_im2col!(
69+
function ∇conv_filter_im2col!(
7070
dw::AbstractArray{T,5}, x::AbstractArray{T,5},
7171
dy::AbstractArray{T,5}, cdims::DenseConvDims;
7272
col::AbstractArray{T,2} = similar(dw, im2col_dims(cdims)),
@@ -95,14 +95,12 @@ See the documentation for `conv_im2col!()` for explanation of optional parameter
9595
K = prod(output_size(cdims))
9696

9797
@inbounds for batch_idx in 1:size(x,5)
98-
# We invoke `@timeit_debug` on the outside of `im2col!()` because inference
99-
# doesn't like us putting it on the inside.
100-
@timeit_debug to "im2col!" im2col!(col, view(x, :, :, :, :, batch_idx), cdims)
98+
im2col!(col, view(x, :, :, :, :, batch_idx), cdims)
10199
GC.@preserve col, dw, dy, begin
102100
col_ptr = pointer(col)
103101
dy_ptr = pointer(dy,(batch_idx - 1)*K*N + 1)
104102
dw_ptr = pointer(dw)
105-
@timeit_debug to "gemm!" gemm!(Val(true), Val(false), M, N, K, alpha, col_ptr, dy_ptr, beta, dw_ptr)
103+
gemm!(Val(true), Val(false), M, N, K, alpha, col_ptr, dy_ptr, beta, dw_ptr)
106104
end
107105

108106
# Because we accumulate over batches in this loop, we must set `beta` equal
@@ -118,7 +116,7 @@ end
118116
Conv2d backward pass onto the input using im2col and GEMM; stores the result in `dx`.
119117
See the documentation for `conv_im2col!()` for explanation of other parameters.
120118
"""
121-
@timeit_debug to function ∇conv_data_im2col!(
119+
function ∇conv_data_im2col!(
122120
dx::AbstractArray{T,5}, dy::AbstractArray{T,5},
123121
w::AbstractArray{T,5}, cdims::DenseConvDims;
124122
col::AbstractArray{T,2} = similar(dx, im2col_dims(cdims)),
@@ -149,9 +147,9 @@ See the documentation for `conv_im2col!()` for explanation of other parameters.
149147
dy_ptr = pointer(dy, (batch_idx - 1)*M*K + 1)
150148
w_ptr = pointer(w)
151149
col_ptr = pointer(col)
152-
@timeit_debug to "gemm!" gemm!(Val(false), Val(true), M, N, K, alpha, dy_ptr, w_ptr, T(0), col_ptr)
150+
gemm!(Val(false), Val(true), M, N, K, alpha, dy_ptr, w_ptr, T(0), col_ptr)
153151
end
154-
@timeit_debug to "col2im!" col2im!(view(dx, :, :, :, :, batch_idx), col, cdims)
152+
col2im!(view(dx, :, :, :, :, batch_idx), col, cdims)
155153
end
156154
return dx
157155
end
@@ -207,77 +205,74 @@ function im2col!(col::AbstractArray{T,2}, x::AbstractArray{T,4},
207205
# We begin by copying the central region of the image which requires no padding at all.
208206
# Eliminating the branches of the fully generalized version below gives us a nice
209207
# speedup on the majority of the data.
210-
@timeit_debug to "im2col!() - central region" begin
211-
@inbounds for c in 1:C_in
212-
# Unpack "central region"
213-
w_region, h_region, d_region = central_region
214-
215-
for kd in 1:kernel_d,
216-
kh in 1:kernel_h,
217-
kw in 1:kernel_w,
218-
d in d_region,
219-
h in h_region,
220-
w in w_region
221-
222-
input_kd = project(d, stride_d, pad_d_lo) + (kd - 1)*dil_d
223-
input_kh = project(h, stride_h, pad_h_lo) + (kh - 1)*dil_h
224-
input_kw = project(w, stride_w, pad_w_lo) + (kw - 1)*dil_w
225-
kidxs = kernel_index(kw, kh, kd, cdims)
208+
@inbounds for c in 1:C_in
209+
# Unpack "central region"
210+
w_region, h_region, d_region = central_region
226211

227-
xval::T = x[input_kw, input_kh, input_kd, c]
228-
col_reshaped[w, h, d, kidxs..., c] = xval
229-
end
212+
for kd in 1:kernel_d,
213+
kh in 1:kernel_h,
214+
kw in 1:kernel_w,
215+
d in d_region,
216+
h in h_region,
217+
w in w_region
218+
219+
input_kd = project(d, stride_d, pad_d_lo) + (kd - 1)*dil_d
220+
input_kh = project(h, stride_h, pad_h_lo) + (kh - 1)*dil_h
221+
input_kw = project(w, stride_w, pad_w_lo) + (kw - 1)*dil_w
222+
kidxs = kernel_index(kw, kh, kd, cdims)
223+
224+
xval::T = x[input_kw, input_kh, input_kd, c]
225+
col_reshaped[w, h, d, kidxs..., c] = xval
230226
end
231227
end
232228

229+
233230
# For each "padded region", we run the fully general version
234-
@timeit_debug to "im2col!() - padded region" begin
235-
@inbounds for (w_region, h_region, d_region) in padded_regions
236-
for c in 1:C_in,
237-
d in d_region,
238-
h in h_region,
239-
w in w_region,
240-
kd in 1:kernel_d,
241-
kh in 1:kernel_h,
242-
kw in 1:kernel_w
231+
@inbounds for (w_region, h_region, d_region) in padded_regions
232+
for c in 1:C_in,
233+
d in d_region,
234+
h in h_region,
235+
w in w_region,
236+
kd in 1:kernel_d,
237+
kh in 1:kernel_h,
238+
kw in 1:kernel_w
243239

244-
input_kd = project(d, stride_d, pad_d_lo) + (kd - 1)*dil_d
245-
input_kh = project(h, stride_h, pad_h_lo) + (kh - 1)*dil_h
246-
input_kw = project(w, stride_w, pad_w_lo) + (kw - 1)*dil_w
240+
input_kd = project(d, stride_d, pad_d_lo) + (kd - 1)*dil_d
241+
input_kh = project(h, stride_h, pad_h_lo) + (kh - 1)*dil_h
242+
input_kw = project(w, stride_w, pad_w_lo) + (kw - 1)*dil_w
247243

248-
kidxs = kernel_index(kw, kh, kd, cdims)
244+
kidxs = kernel_index(kw, kh, kd, cdims)
249245

250-
# If this d is off the edge, then deal with the entire plane
251-
# in one fell swoop, like a ravenous flock of crows. CAW CAW.
252-
if input_kd <= 0 || input_kd > depth
253-
for kh in 1:kernel_h,
254-
kw in 1:kernel_w
255-
col_reshaped[w, h, d, kidxs..., c] = T(0)
256-
end
257-
continue
258-
end
259-
260-
# Same for `h`, but in this case it's only a line, not a plane.
261-
# This results in slightly less caw'ing.
262-
if input_kh <= 0 || input_kh > height
263-
for kw in 1:kernel_w
264-
col_reshaped[w, h, d, kidxs..., c] = T(0)
265-
end
266-
continue
246+
# If this d is off the edge, then deal with the entire plane
247+
# in one fell swoop, like a ravenous flock of crows. CAW CAW.
248+
if input_kd <= 0 || input_kd > depth
249+
for kh in 1:kernel_h,
250+
kw in 1:kernel_w
251+
col_reshaped[w, h, d, kidxs..., c] = T(0)
267252
end
253+
continue
254+
end
268255

269-
# If this `w` is off the edge it and only it gets cleared out
270-
if input_kw <= 0 || input_kw > width
256+
# Same for `h`, but in this case it's only a line, not a plane.
257+
# This results in slightly less caw'ing.
258+
if input_kh <= 0 || input_kh > height
259+
for kw in 1:kernel_w
271260
col_reshaped[w, h, d, kidxs..., c] = T(0)
272-
continue
273261
end
262+
continue
263+
end
274264

275-
# Copy the data over
276-
xval::T = x[input_kw, input_kh, input_kd, c]
277-
col_reshaped[w, h, d, kidxs..., c] = xval
265+
# If this `w` is off the edge it and only it gets cleared out
266+
if input_kw <= 0 || input_kw > width
267+
col_reshaped[w, h, d, kidxs..., c] = T(0)
268+
continue
278269
end
270+
271+
# Copy the data over
272+
xval::T = x[input_kw, input_kh, input_kd, c]
273+
col_reshaped[w, h, d, kidxs..., c] = xval
279274
end
280-
end
275+
end
281276
end
282277

283278

src/impl/depthwiseconv_direct.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ channels in `x` is the last, not the second-to-last, as in a normal dense convol
1818
1919
See the docstring for `conv_direct!()` for more on the optional parameters.
2020
"""
21-
@timeit_debug to function depthwiseconv_direct!(
21+
function depthwiseconv_direct!(
2222
y::AbstractArray{yT,5}, x::AbstractArray{xT,5},
2323
w::AbstractArray{wT,5}, cdims::DepthwiseConvDims;
2424
alpha::yT = yT(1), beta::yT = yT(0)) where {yT, xT, wT}
@@ -95,7 +95,7 @@ for each batch and channel independently.
9595
"""
9696
∇depthwiseconv_data_direct!
9797

98-
@timeit_debug to function ∇depthwiseconv_data_direct!(
98+
function ∇depthwiseconv_data_direct!(
9999
dx::AbstractArray{xT,5}, dy::AbstractArray{yT,5},
100100
w::AbstractArray{wT,5}, cdims::DepthwiseConvDims;
101101
alpha::xT=xT(1), beta::xT=xT(0)) where {xT, yT, wT}
@@ -128,7 +128,7 @@ Calculate the gradient imposed upon `w` in the depthwise convolution `y = x * w`
128128
"""
129129
∇depthwiseconv_filter_direct!
130130

131-
@timeit_debug to function ∇depthwiseconv_filter_direct!(
131+
function ∇depthwiseconv_filter_direct!(
132132
dw::AbstractArray{wT,5}, x::AbstractArray{xT,5},
133133
dy::AbstractArray{yT,5}, cdims::DepthwiseConvDims;
134134
alpha::wT=wT(1),beta::wT=wT(0)) where {xT, yT, wT}

src/impl/depthwiseconv_im2col.jl

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ See `conv_im2col!()` for an explanation of optional parameters.
1010
"""
1111
depthwiseconv_im2col!
1212

13-
@timeit_debug to function depthwiseconv_im2col!(
13+
function depthwiseconv_im2col!(
1414
y::AbstractArray{T,5}, x::AbstractArray{T,5},
1515
w::AbstractArray{T,5}, cdims::DepthwiseConvDims;
1616
col::AbstractArray{T,2} = similar(x, im2col_dims(cdims)),
@@ -28,9 +28,7 @@ depthwiseconv_im2col!
2828

2929
dcdims = DenseConvDims(cdims)
3030
@inbounds for batch_idx in 1:size(x)[end]
31-
# We invoke `@timeit_debug` on the outside of `im2col!()` because inference
32-
# doesn't like us putting it on the inside.
33-
@timeit_debug to "im2col!" im2col!(col, view(x, :, :, :, :, batch_idx), dcdims)
31+
im2col!(col, view(x, :, :, :, :, batch_idx), dcdims)
3432

3533
# We do a separate convolution for each channel in x, as we must
3634
for c_in in 1:channels_in(cdims)
@@ -54,7 +52,7 @@ See the documentation for `conv_im2col!()` for explanation of optional parameter
5452
"""
5553
∇depthwiseconv_filter_im2col!
5654

57-
@timeit_debug to function ∇depthwiseconv_filter_im2col!(
55+
function ∇depthwiseconv_filter_im2col!(
5856
dw::AbstractArray{T,5}, x::AbstractArray{T,5},
5957
dy::AbstractArray{T,5}, cdims::DepthwiseConvDims;
6058
col::AbstractArray{T,2} = similar(dw, im2col_dims(cdims)),
@@ -66,9 +64,7 @@ See the documentation for `conv_im2col!()` for explanation of optional parameter
6664
K = prod(output_size(cdims))
6765

6866
@inbounds for batch_idx in 1:size(x)[end]
69-
# We invoke `@timeit_debug` on the outside of `im2col!()` because inference
70-
# doesn't like us putting it on the inside.
71-
@timeit_debug to "im2col!" im2col!(col, view(x, :, :, :, :, batch_idx), cdims)
67+
im2col!(col, view(x, :, :, :, :, batch_idx), cdims)
7268

7369
# We do a separate convolution for each channel in x, as we must
7470
for c_in in 1:channels_in(cdims)
@@ -96,7 +92,7 @@ See the documentation for `conv_im2col!()` for explanation of optional parameter
9692
"""
9793
∇depthwiseconv_data_im2col!
9894

99-
@timeit_debug to function ∇depthwiseconv_data_im2col!(
95+
function ∇depthwiseconv_data_im2col!(
10096
dx::AbstractArray{T,5}, dy::AbstractArray{T,5},
10197
w::AbstractArray{T,5}, cdims::DepthwiseConvDims;
10298
col::AbstractArray{T,2} = similar(dx, im2col_dims(cdims)),
@@ -118,7 +114,7 @@ See the documentation for `conv_im2col!()` for explanation of optional parameter
118114
gemm!(Val(false), Val(true), M, N, K, alpha, dy_ptr, w_ptr, T(0), col_ptr)
119115
end
120116
end
121-
@timeit_debug to "col2im!" col2im!(view(dx, :, :, :, :, batch_idx), col, cdims)
117+
col2im!(view(dx, :, :, :, :, batch_idx), col, cdims)
122118
end
123119
return dx
124120
end

src/nnpack/impl.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ function maxpool_nnpack!(y::A, x::A, pdims::PoolDims) where {A<:Array{Float32, 4
55
stride = stride(pdims), threadpool = threadpool)
66
end
77

8-
@timeit_debug to function conv_nnpack!(y::A1, x::A1, w::A1, cdims::ConvDims;
8+
function conv_nnpack!(y::A1, x::A1, w::A1, cdims::ConvDims;
99
b::A2 = zeros(Float32, size(x, 3)),
1010
algo = UInt32(0)) where {A1<:Array{Float32, 4},
1111
A2<:Array{Float32, 1}}
@@ -20,7 +20,7 @@ end
2020
stride = stride(cdims), threadpool = threadpool)
2121
end
2222

23-
@timeit_debug to function ∇conv_data_nnpack!(dx::A, dy::A, w::A, cdims::ConvDims;
23+
function ∇conv_data_nnpack!(dx::A, dy::A, w::A, cdims::ConvDims;
2424
algo = UInt32(0)) where{A<:Array{Float32, 4}}
2525
check_dims(size(dx), size(w), size(dy), cdims)
2626
threadpool = select_threadpool(cdims, size(y, 4))
@@ -33,7 +33,7 @@ end
3333
stride = stride(cdims), threadpool = threadpool)
3434
end
3535

36-
@timeit_debug to function ∇conv_filter_nnpack!(dw::A, x::A, dy::A, cdims::ConvDims;
36+
function ∇conv_filter_nnpack!(dw::A, x::A, dy::A, cdims::ConvDims;
3737
algo = UInt32(0)) where{A<:Array{Float32, 4}}
3838
check_dims(size(x), size(dw), size(dy), cdims)
3939
threadpool = select_threadpool(cdims, size(y, 4))

src/nnpack/interface.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ for (front_name, backend) in (
77
:∇conv_filter => :_nnpack,
88
)
99
@eval begin
10-
@timeit_debug to function $(Symbol("$(front_name)$(backend)!"))(
10+
function $(Symbol("$(front_name)$(backend)!"))(
1111
out::Array{T1,4}, in1::Array{T2,4}, in2::Array{T3,4},
1212
cdims::ConvDims; kwargs...) where {T1, T2, T3}
1313
@warn "Automatically converting input tensor to Float32. This will have performance implications" maxlog=1

test/inference.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
using NNlib, Test
2+
using NNlib: conv_direct, conv_im2col
3+
4+
@testset "Conv Inference" begin
5+
x = rand(10, 10, 3, 2)
6+
w = rand(3, 3, 3, 1)
7+
8+
impl = [conv, conv_direct, conv_im2col]
9+
NNlib.is_nnpack_available() && push!(impl, NNlib.conv_nnpack)
10+
11+
for T in impl
12+
@inferred T(x, w, DenseConvDims(x, w))
13+
end
14+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@ using NNlib, Test
33
include("activation.jl")
44
include("conv.jl")
55
include("pooling.jl")
6+
include("inference.jl")

0 commit comments

Comments
 (0)