Skip to content

Commit 41d91eb

Browse files
authored
Merge pull request #72 from ayush1999/dev_flipkernel
Expose flipkernel argument
2 parents 5727643 + 55a2b52 commit 41d91eb

File tree

1 file changed

+24
-24
lines changed

1 file changed

+24
-24
lines changed

src/conv.jl

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,11 @@ function crosscor(x::A, w::A; pad = 0, stride = 1, dilation = 1) where A<:Abstra
3636
x, w, pad = pad_, stride = stride_, dilation = dilation)
3737
end
3838

39-
∇conv_data(dy::A, x::A, w::A; pad = 0, stride = 1, dilation = 1) where A<:AbstractArray =
40-
∇conv_data!(zero(x), dy, x, w; pad = pad, stride = stride, dilation = dilation)
39+
∇conv_data(dy::A, x::A, w::A; pad = 0, stride = 1, dilation = 1, flipkernel = 0) where A<:AbstractArray =
40+
∇conv_data!(zero(x), dy, x, w; pad = pad, stride = stride, dilation = dilation, flipkernel=flipkernel)
4141

42-
∇conv_filter(dy::A, x::A, w::A; pad = 0, stride = 1, dilation = 1) where A<:AbstractArray =
43-
∇conv_filter!(zero(w), dy, x, w; pad = pad, stride = stride, dilation = dilation)
42+
∇conv_filter(dy::A, x::A, w::A; pad = 0, stride = 1, dilation = 1, flipkernel=0) where A<:AbstractArray =
43+
∇conv_filter!(zero(w), dy, x, w; pad = pad, stride = stride, dilation = dilation, flipkernel=flipkernel)
4444

4545
# N-D dispatch
4646

@@ -58,17 +58,17 @@ end
5858

5959
function ∇conv_filter!(dw::AbstractArray{T,3}, dy::AbstractArray{T,3},
6060
x::AbstractArray{T,3}, w::AbstractArray{T,3};
61-
pad = 0, stride = 1, dilation = 1) where T
61+
pad = 0, stride = 1, dilation = 1, flipkernel=0) where T
6262
args = map(x -> reshape(x, size(x,1),1,size(x,2),size(x,3)), (dw, dy, x, w))
63-
∇conv_filter!(args..., pad = (pad...,0), stride = (stride...,1), dilation = (dilation...,1))
63+
∇conv_filter!(args..., pad = (pad...,0), stride = (stride...,1), dilation = (dilation...,1), flipkernel=flipkernel)
6464
return dw
6565
end
6666

6767
function ∇conv_data!(dx::AbstractArray{T,3}, dy::AbstractArray{T,3},
6868
x::AbstractArray{T,3}, w::AbstractArray{T,3};
69-
pad = 0, stride = 1, dilation = 1) where T
69+
pad = 0, stride = 1, dilation = 1, flipkernel = 0) where T
7070
args = map(x -> reshape(x, size(x,1),1,size(x,2),size(x,3)), (dx, dy, x, w))
71-
∇conv_data!(args..., pad = (pad...,0), stride = (stride...,1), dilation = (dilation..., 1))
71+
∇conv_data!(args..., pad = (pad...,0), stride = (stride...,1), dilation = (dilation..., 1), flipkernel = flipkernel)
7272
return dx
7373
end
7474

@@ -77,24 +77,24 @@ conv!(y::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};
7777
conv2d!(y, x, w, padding = pad, stride = stride, dilation = dilation, mode=flipkernel)
7878

7979
∇conv_filter!(dw::AbstractArray{T,4}, dy::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};
80-
pad = 0, stride = 1, dilation = 1) where T =
81-
conv2d_grad_w!(dw, x, w, dy, padding = pad, stride = stride, dilation = dilation)
80+
pad = 0, stride = 1, dilation = 1, flipkernel=0) where T =
81+
conv2d_grad_w!(dw, x, w, dy, padding = pad, stride = stride, dilation = dilation, mode=flipkernel)
8282

8383
∇conv_data!(dx::AbstractArray{T,4}, dy::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};
84-
pad = 0, stride = 1, dilation = 1) where T =
85-
conv2d_grad_x!(dx, x, w, dy, padding = pad, stride = stride, dilation = dilation)
84+
pad = 0, stride = 1, dilation = 1, flipkernel=0) where T =
85+
conv2d_grad_x!(dx, x, w, dy, padding = pad, stride = stride, dilation = dilation, mode=flipkernel)
8686

8787
conv!(y::AbstractArray{T,5}, x::AbstractArray{T,5}, w::AbstractArray{T,5};
8888
pad = 0, stride = 1, dilation = 1, flipkernel=0) where T =
8989
conv3d!(y, x, w, padding = pad, stride = stride, dilation = dilation, mode=flipkernel)
9090

9191
∇conv_filter!(dw::AbstractArray{T,5}, dy::AbstractArray{T,5}, x::AbstractArray{T,5}, w::AbstractArray{T,5};
92-
pad = 0, stride = 1, dilation = 1) where T =
93-
conv3d_grad_w!(dw, x, w, dy, padding = pad, stride = stride, dilation = dilation)
92+
pad = 0, stride = 1, dilation = 1, flipkernel=0) where T =
93+
conv3d_grad_w!(dw, x, w, dy, padding = pad, stride = stride, dilation = dilation, mode=flipkernel)
9494

9595
∇conv_data!(dx::AbstractArray{T,5}, dy::AbstractArray{T,5}, x::AbstractArray{T,5}, w::AbstractArray{T,5};
96-
pad = 0, stride = 1, dilation = 1) where T =
97-
conv3d_grad_x!(dx, x, w, dy, padding = pad, stride = stride, dilation = dilation)
96+
pad = 0, stride = 1, dilation = 1, flipkernel=0) where T =
97+
conv3d_grad_x!(dx, x, w, dy, padding = pad, stride = stride, dilation = dilation, mode=flipkernel)
9898

9999
# Depthwise Conv
100100

@@ -120,19 +120,19 @@ depthwisecrosscor!(y::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArra
120120
pad = 0, stride = 1) where T =
121121
depthwiseconv!(y, x, w, pad = pad, stride = stride, flipkernel=1)
122122

123-
∇depthwiseconv_data(dy::A, x::A, w::A; pad = 0, stride = 1) where A<:AbstractArray =
124-
∇depthwiseconv_data!(zero(x), dy, x, w; pad = pad, stride = stride)
123+
∇depthwiseconv_data(dy::A, x::A, w::A; pad = 0, stride = 1, flipkernel=0) where A<:AbstractArray =
124+
∇depthwiseconv_data!(zero(x), dy, x, w; pad = pad, stride = stride, flipkernel=flipkernel)
125125

126-
∇depthwiseconv_filter(dy::A, x::A, w::A; pad = 0, stride = 1) where A<:AbstractArray =
127-
∇depthwiseconv_filter!(zero(w), dy, x, w; pad = pad, stride = stride)
126+
∇depthwiseconv_filter(dy::A, x::A, w::A; pad = 0, stride = 1, flipkernel=0) where A<:AbstractArray =
127+
∇depthwiseconv_filter!(zero(w), dy, x, w; pad = pad, stride = stride, flipkernel=flipkernel)
128128

129129
∇depthwiseconv_filter!(dw::AbstractArray{T,4}, dy::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};
130-
pad = 0, stride = 1) where T =
131-
depthwiseconv2d_grad_w!(dw, x, w, dy, padding = pad, stride = stride)
130+
pad = 0, stride = 1, flipkernel=0) where T =
131+
depthwiseconv2d_grad_w!(dw, x, w, dy, padding = pad, stride = stride, mode=flipkernel)
132132

133133
∇depthwiseconv_data!(dx::AbstractArray{T,4}, dy::AbstractArray{T,4}, x::AbstractArray{T,4}, w::AbstractArray{T,4};
134-
pad = 0, stride = 1) where T =
135-
depthwiseconv2d_grad_x!(dx, x, w, dy, padding = pad, stride = stride)
134+
pad = 0, stride = 1, flipkernel=0) where T =
135+
depthwiseconv2d_grad_x!(dx, x, w, dy, padding = pad, stride = stride, mode=flipkernel)
136136

137137
# Pooling
138138

0 commit comments

Comments
 (0)