Skip to content

Commit 7a3f865

Browse files
committed
improve performance for pooling in forward mode
1 parent 085adb7 commit 7a3f865

File tree

1 file changed

+24
-10
lines changed

1 file changed

+24
-10
lines changed

src/impl/pool.jl

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ function max_pooling2d_fwd!(x::AbstractArray{T,4}, y::AbstractArray{T,4},
22
width::Int, height::Int, channels::Int, num::Int, pooled_width::Int,
33
pooled_height::Int, kernel_w::Int, kernel_h::Int, pad_w::Int, pad_h::Int,
44
stride_w::Int, stride_h::Int) where T
5-
for n = 1:num, c = 1:channels, ph = 1:pooled_height, pw = 1:pooled_width
5+
@inbounds for n = 1:num, c = 1:channels, ph = 1:pooled_height, pw = 1:pooled_width
66
hstart = (ph - 1)*stride_h - pad_h
77
wstart = (pw - 1)*stride_w - pad_w
88
hend = min(hstart + kernel_h, height)
@@ -11,7 +11,11 @@ function max_pooling2d_fwd!(x::AbstractArray{T,4}, y::AbstractArray{T,4},
1111
hstart = max(hstart, 0) + 1
1212
wstart = max(wstart, 0) + 1
1313

14-
y[pw, ph, c, n] = maximum(x[wstart:wend, hstart:hend, c, n])
14+
m = typemin(T)
15+
for j in hstart:hend, i in wstart:wend
16+
m = max(x[i, j, c, n])
17+
end
18+
y[pw, ph, c, n] = m
1519
end
1620
end
1721

@@ -69,7 +73,7 @@ function mean_pooling2d_fwd!(x::AbstractArray{T,4}, y::AbstractArray{T,4},
6973
pooled_height::Int, kernel_w::Int, kernel_h::Int,pad_w::Int, pad_h::Int,
7074
stride_w::Int, stride_h::Int) where T
7175
kernel_size = kernel_w * kernel_h
72-
for n = 1:num, c = 1:channels, ph = 1:pooled_height, pw = 1:pooled_width
76+
@inbounds for n = 1:num, c = 1:channels, ph = 1:pooled_height, pw = 1:pooled_width
7377
hstart = (ph - 1) * stride_h - pad_h
7478
wstart = (pw - 1) * stride_w - pad_w
7579
hend = min(hstart + kernel_h, height)
@@ -78,7 +82,11 @@ function mean_pooling2d_fwd!(x::AbstractArray{T,4}, y::AbstractArray{T,4},
7882
hstart = max(hstart, 0) + 1
7983
wstart = max(wstart, 0) + 1
8084

81-
y[pw, ph, c, n] = sum(x[wstart:wend, hstart:hend, c, n]) / kernel_size
85+
s = zero(T)
86+
for j in hstart:hend, i in wstart:wend
87+
s += x[i, j, c, n]
88+
end
89+
y[pw, ph, c, n] = s / kernel_size
8290
end
8391
end
8492

@@ -132,7 +140,7 @@ function max_pooling3d_fwd!(x::AbstractArray{T,5}, y::AbstractArray{T,5},
132140
width::Int, height::Int, depth::Int, channels::Int, num::Int, pooled_width::Int,
133141
pooled_height::Int, pooled_depth::Int, kernel_w::Int, kernel_h::Int, kernel_d::Int,
134142
pad_w::Int, pad_h::Int, pad_d::Int, stride_w::Int, stride_h::Int, stride_d::Int) where T
135-
for n = 1:num, c = 1:channels, pd = 1:pooled_depth, ph = 1:pooled_height, pw = 1:pooled_width
143+
@inbounds for n = 1:num, c = 1:channels, pd = 1:pooled_depth, ph = 1:pooled_height, pw = 1:pooled_width
136144
dstart = (pd - 1)* stride_d - pad_d
137145
hstart = (ph - 1)* stride_h - pad_h
138146
wstart = (pw - 1)* stride_w - pad_w
@@ -145,8 +153,11 @@ function max_pooling3d_fwd!(x::AbstractArray{T,5}, y::AbstractArray{T,5},
145153
hstart = max(hstart, 0) + 1
146154
wstart = max(wstart, 0) + 1
147155

148-
y[pw, ph, pd, c, n] =
149-
maximum(x[wstart:wend, hstart:hend, dstart:dend, c, n])
156+
m = typemin(T)
157+
for k in dstart:dend, j in hstart:hend, i in wstart:wend
158+
m = max(x[i, j, k, c, n])
159+
end
160+
y[pw, ph, pd, c, n] = m
150161
end
151162
end
152163

@@ -213,7 +224,7 @@ function mean_pooling3d_fwd!(x::AbstractArray{T,5}, y::AbstractArray{T,5},
213224

214225
kernel_size = kernel_w * kernel_h * kernel_d
215226
#pragma omp parallel for
216-
for n = 1:num, c = 1:channels, pd = 1:pooled_depth, ph = 1:pooled_height, pw = 1:pooled_width
227+
@inbounds for n = 1:num, c = 1:channels, pd = 1:pooled_depth, ph = 1:pooled_height, pw = 1:pooled_width
217228
dstart = (pd - 1) * stride_d - pad_d
218229
hstart = (ph - 1) * stride_h - pad_h
219230
wstart = (pw - 1) * stride_w - pad_w
@@ -226,8 +237,11 @@ function mean_pooling3d_fwd!(x::AbstractArray{T,5}, y::AbstractArray{T,5},
226237
hstart = max(hstart, 0) + 1
227238
wstart = max(wstart, 0) + 1
228239

229-
y[pw, ph, pd, c, n] =
230-
sum(x[wstart:wend, hstart:hend, dstart:dend, c, n]) / kernel_size
240+
s = zero(T)
241+
for k in dstart:dend, j in hstart:hend, i in wstart:wend
242+
s += x[i, j, k, c, n]
243+
end
244+
y[pw, ph, pd, c, n] = s / kernel_size
231245
end
232246
end
233247

0 commit comments

Comments
 (0)