1
+ flipweight (w:: AbstractArray{<:Any,4} ) = w[end : - 1 : 1 ,end : - 1 : 1 ,:,:]
2
+
1
3
function check_support (x, k, pad, stride, dilation = 1 )
2
4
fallback = false
3
5
dilation == 1 || dilation == (1 , 1 ) || (fallback = true )
53
55
54
56
function conv! (y:: A1 , x:: A1 , w:: A1 , b:: A2 ; pad = 0 , stride = 1 , dilation = 1 , algo = UInt32 (0 ), flipkernel = 0 ) where {A1<: AbstractArray{Float32, 4} , A2<: AbstractArray{Float32, 1} }
55
57
if flipkernel == 0
56
- w = reverse ( reverse (w, dims = 1 ), dims = 2 )
58
+ w = flipweight (w )
57
59
end
58
60
nnp_convolution_output (y, x, w, b, algo = algo, padding = pad, stride = stride)
59
61
end
@@ -74,7 +76,7 @@ function ∇conv_data(dy::A, x::A, w::A; pad = 0, stride = 1, dilation = 1, algo
74
76
end
75
77
76
78
function ∇conv_data! (dx:: A , dy:: A , x:: A , w:: A ; pad = 0 , stride = 1 , dilation = 1 , algo = UInt32 (0 ), flipkernel = 0 ) where A<: AbstractArray{Float32, 4}
77
- flipkernel == 0 && (w = reverse ( reverse (w, dims = 1 ), dims = 2 ))
79
+ flipkernel == 0 && (w = flipweight (w ))
78
80
nnp_convolution_input_gradient (dx, x, dy, w, padding = pad, stride = stride, algo = algo)
79
81
end
80
82
@@ -88,7 +90,7 @@ function ∇conv_filter(dy::A, x::A, w::A; pad = 0, stride = 1, dilation = 1, al
88
90
end
89
91
90
92
function ∇conv_filter! (dw:: A , dy:: A , x:: A , w:: A ; pad = 0 , stride = 1 , dilation = 1 , algo = UInt32 (0 ), flipkernel = 0 ) where A<: AbstractArray{Float32, 4}
91
- flipkernel == 0 && (w = reverse ( reverse (w, dims = 1 ), dims = 2 ))
92
- dw . = nnp_convolution_kernel_gradient (dw, x, dy, w, padding = pad, stride = stride, algo = algo)
93
- flipkernel == 0 ? reverse ( reverse (dw, dims = 1 ), dims = 2 ) : dw
93
+ flipkernel == 0 && (w = flipweight (w ))
94
+ nnp_convolution_kernel_gradient (dw, x, dy, w, padding = pad, stride = stride, algo = algo)
95
+ flipkernel == 0 ? flipweight (dw ) : dw
94
96
end
0 commit comments