Skip to content

Commit efb32ce

Browse files
committed
pull out flipweight
1 parent 1fa88c0 commit efb32ce

File tree

1 file changed

+7
-5
lines changed

1 file changed

+7
-5
lines changed

src/nnpack/interface.jl

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
"""
    flipweight(w::AbstractArray{<:Any,4})

Return `w` indexed with its first and second dimensions in reverse order
(from the last index down to 1); the third and fourth dimensions are taken
in full and keep their original order. `w` itself is not modified.
"""
function flipweight(w::AbstractArray{<:Any,4})
    # Descending ranges over the first two dimensions; `lastindex(w, d)` is
    # what `end` denotes in position `d` of an indexing expression.
    dim1 = lastindex(w, 1):-1:1
    dim2 = lastindex(w, 2):-1:1
    return w[dim1, dim2, :, :]
end
2+
13
function check_support(x, k, pad, stride, dilation = 1)
24
fallback = false
35
dilation == 1 || dilation == (1, 1) || (fallback = true)
@@ -53,7 +55,7 @@ end
5355

5456
function conv!(y::A1, x::A1, w::A1, b::A2; pad = 0, stride = 1, dilation = 1, algo = UInt32(0), flipkernel = 0) where {A1<:AbstractArray{Float32, 4}, A2<:AbstractArray{Float32, 1}}
5557
if flipkernel == 0
56-
w = reverse(reverse(w, dims=1), dims=2)
58+
w = flipweight(w)
5759
end
5860
nnp_convolution_output(y, x, w, b, algo = algo, padding = pad, stride = stride)
5961
end
@@ -74,7 +76,7 @@ function ∇conv_data(dy::A, x::A, w::A; pad = 0, stride = 1, dilation = 1, algo
7476
end
7577

7678
function ∇conv_data!(dx::A, dy::A, x::A, w::A; pad = 0, stride = 1, dilation = 1, algo = UInt32(0), flipkernel = 0) where A<:AbstractArray{Float32, 4}
77-
flipkernel == 0 && (w = reverse(reverse(w, dims=1), dims=2))
79+
flipkernel == 0 && (w = flipweight(w))
7880
nnp_convolution_input_gradient(dx, x, dy, w, padding = pad, stride = stride, algo = algo)
7981
end
8082

@@ -88,7 +90,7 @@ function ∇conv_filter(dy::A, x::A, w::A; pad = 0, stride = 1, dilation = 1, al
8890
end
8991

9092
function ∇conv_filter!(dw::A, dy::A, x::A, w::A; pad = 0, stride = 1, dilation = 1, algo = UInt32(0), flipkernel = 0) where A<:AbstractArray{Float32, 4}
91-
flipkernel == 0 && (w = reverse(reverse(w, dims=1), dims=2))
92-
dw .= nnp_convolution_kernel_gradient(dw, x, dy, w, padding = pad, stride = stride, algo = algo)
93-
flipkernel == 0 ? reverse(reverse(dw, dims=1), dims=2) : dw
93+
flipkernel == 0 && (w = flipweight(w))
94+
nnp_convolution_kernel_gradient(dw, x, dy, w, padding = pad, stride = stride, algo = algo)
95+
flipkernel == 0 ? flipweight(dw) : dw
9496
end

0 commit comments

Comments
 (0)