@@ -114,7 +114,7 @@ for backend in (Symbol(), :_direct, :_im2col)
114
114
fill! (y, xT (0 ))
115
115
return $ (Symbol (" $(name)$(backend) !" ))(y, x, pdims; kwargs... )
116
116
end
117
-
117
+
118
118
# Backprops too
119
119
@timeit_debug to function $ (Symbol (" ∇$(name)$(backend) " ))(
120
120
dy:: AbstractArray{T,N} , y:: AbstractArray{T,N} ,
@@ -136,3 +136,20 @@ if is_nnpack_available()
136
136
return func (x, pdims; kwargs... )
137
137
end
138
138
end
139
+
140
+ expand (N, i:: Tuple ) = i
141
+ expand (N, i:: Integer ) = ntuple (_ -> i, N)
142
+
143
+ function maxpool (x, k:: NTuple{N, Integer} ; pad = 0 , stride = k) where N
144
+ pad = expand (Val (2 * N), pad)
145
+ stride = expand (Val (N), stride)
146
+ pdims = PoolDims (x, k; padding = pad, stride = stride)
147
+ return maxpool (x, pdims)
148
+ end
149
+
150
+ function meanpool (x, k:: NTuple{N, Integer} ; pad = 0 , stride = k) where N
151
+ pad = expand (Val (2 * N), pad)
152
+ stride = expand (Val (N), stride)
153
+ pdims = PoolDims (x, k; padding = pad, stride = stride)
154
+ return meanpool (x, pdims)
155
+ end
0 commit comments