|
1 |
| -export conv, conv!, ∇conv_data, ∇conv_data!, ∇conv_filter, ∇conv_filter! |
| 1 | +export conv, conv!, ∇conv_data, ∇conv_data!, ∇conv_filter, ∇conv_filter!, depthwiseconv, |
| 2 | + depthwiseconv!, ∇depthwiseconv_data, ∇depthwiseconv_data!, ∇depthwiseconv_filter, |
| 3 | + ∇depthwiseconv_filter! |
2 | 4 |
|
3 | 5 | ## Convolution API
|
4 | 6 | #
|
@@ -161,3 +163,19 @@ if is_nnpack_available()
|
161 | 163 | return conv_nnpack(x, w, cdims; kwargs...)
|
162 | 164 | end
|
163 | 165 | end
|
| 166 | + |
| 167 | +function conv(x, w::AbstractArray{T, N}; stride = 1, pad = 0, dilation = 1) where {T, N} |
| 168 | + stride = expand(Val(N-2), stride) |
| 169 | + pad = expand(Val(N-2), pad) |
| 170 | + dilation = expand(Val(N-2), dilation) |
| 171 | + cdims = DenseConvDims(x, w; stride = stride, padding = pad, dilation = dilation) |
| 172 | + return conv(x, w, cdims) |
| 173 | +end |
| 174 | + |
| 175 | +function depthwiseconv(x, w::AbstractArray{T, N}; stride = 1, pad = 0, dilation = 1) where {T, N} |
| 176 | + stride = expand(Val(N-2), stride) |
| 177 | + pad = expand(Val(N-2), pad) |
| 178 | + dilation = expand(Val(N-2), dilation) |
| 179 | + cdims = DepthwiseConvDims(x, w; stride = stride, padding = pad, dilation = dilation) |
| 180 | + return depthwiseconv(x, w, cdims) |
| 181 | +end |
0 commit comments