Skip to content

Commit f3708f9

Browse files
committed
maxpool and meanpool added
1 parent 1584f86 commit f3708f9

File tree

1 file changed

+18
-1
lines changed

1 file changed

+18
-1
lines changed

src/pooling.jl

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ for backend in (Symbol(), :_direct, :_im2col)
114114
fill!(y, xT(0))
115115
return $(Symbol("$(name)$(backend)!"))(y, x, pdims; kwargs...)
116116
end
117-
117+
118118
# Backprops too
119119
@timeit_debug to function $(Symbol("$(name)$(backend)"))(
120120
dy::AbstractArray{T,N}, y::AbstractArray{T,N},
@@ -136,3 +136,20 @@ if is_nnpack_available()
136136
return func(x, pdims; kwargs...)
137137
end
138138
end
139+
140+
expand(N, i::Tuple) = i
141+
expand(N, i::Integer) = ntuple(_ -> i, N)
142+
143+
function maxpool(x, k::NTuple{N, Integer}; pad = 0, stride = k) where N
144+
pad = expand(Val(2*N), pad)
145+
stride = expand(Val(N), stride)
146+
pdims = PoolDims(x, k; padding = pad, stride = stride)
147+
return maxpool(x, pdims)
148+
end
149+
150+
function meanpool(x, k::NTuple{N, Integer}; pad = 0, stride = k) where N
151+
pad = expand(Val(2*N), pad)
152+
stride = expand(Val(N), stride)
153+
pdims = PoolDims(x, k; padding = pad, stride = stride)
154+
return meanpool(x, pdims)
155+
end

0 commit comments

Comments
 (0)