Skip to content

Commit df48e61

Browse files
committed
Wrappers for conv and depthwiseconv
1 parent f5fce7a commit df48e61

File tree

1 file changed

+19
-1
lines changed

1 file changed

+19
-1
lines changed

src/conv.jl

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
export conv, conv!, ∇conv_data, ∇conv_data!, ∇conv_filter, ∇conv_filter!
1+
export conv, conv!, ∇conv_data, ∇conv_data!, ∇conv_filter, ∇conv_filter!, depthwiseconv,
2+
depthwiseconv!, ∇depthwiseconv_data, ∇depthwiseconv_data!, ∇depthwiseconv_filter,
3+
∇depthwiseconv_filter!
24

35
## Convolution API
46
#
@@ -161,3 +163,19 @@ if is_nnpack_available()
161163
return conv_nnpack(x, w, cdims; kwargs...)
162164
end
163165
end
166+
167+
function conv(x, w::AbstractArray{T, N}; stride = 1, pad = 0, dilation = 1) where {T, N}
168+
stride = expand(Val(N-2), stride)
169+
pad = expand(Val(N-2), pad)
170+
dilation = expand(Val(N-2), dilation)
171+
cdims = DenseConvDims(x, w; stride = stride, padding = pad, dilation = dilation)
172+
return conv(x, w, cdims)
173+
end
174+
175+
function depthwiseconv(x, w::AbstractArray{T, N}; stride = 1, pad = 0, dilation = 1) where {T, N}
176+
stride = expand(Val(N-2), stride)
177+
pad = expand(Val(N-2), pad)
178+
dilation = expand(Val(N-2), dilation)
179+
cdims = DepthwiseConvDims(x, w; stride = stride, padding = pad, dilation = dilation)
180+
return depthwiseconv(x, w, cdims)
181+
end

0 commit comments

Comments
 (0)