@@ -91,7 +91,7 @@ function ∇conv_data(dy::AA{4}, x::AA{4}, w::AA{4}; pad = 0, stride = 1, dilati
91
91
dilation == 1 || dilation == (1 , 1 ) || error (" NNPACK does not support dilation > 1" )
92
92
pad_, stride_ = padtuple (x, pad), padtuple (x, stride)
93
93
((size (x, 1 ) - size (w, 1 ) + 2 * pad_[1 ]) % stride_[1 ] == 0 && (size (x, 2 ) - size (w, 2 ) + 2 * pad_[2 ]) % stride_[2 ] == 0 ) || error (" Choose the stride, pad and kernel size properly" )
94
- ∇conv_data! (zeros (Float32, size (x)), dy, x, w; pad = pad , stride = stride , dilation = dilation, algo = UInt32 (algo))
94
+ ∇conv_data! (zeros (Float32, size (x)), dy, x, w; pad = pad_ , stride = stride_ , dilation = dilation, algo = UInt32 (algo))
95
95
end
96
96
97
97
∇conv_data! (dx:: AbstractArray{Float64, 4} , dy:: AbstractArray{Float64, 4} , x:: AbstractArray{Float64, 4} , w:: AbstractArray{Float64, 4} ; pad = 0 , stride = 1 , dilation = 1 , algo = UInt32 (0 ), flipkernel = 0 ) =
@@ -109,7 +109,7 @@ function ∇conv_filter(dy::AA{4}, x::AA{4}, w::AA{4}; pad = 0, stride = 1, dila
109
109
dilation == 1 || dilation == (1 , 1 ) || error (" NNPACK does not support dilation > 1" )
110
110
pad_, stride_ = padtuple (x, pad), padtuple (x, stride)
111
111
((size (x, 1 ) - size (w, 1 ) + 2 * pad_[1 ]) % stride_[1 ] == 0 && (size (x, 2 ) - size (w, 2 ) + 2 * pad_[2 ]) % stride_[2 ] == 0 ) || error (" Choose the stride, pad and kernel size properly" )
112
- ∇conv_filter! (zeros (Float32, size (w)), dy, x, w; pad = pad , stride = stride , dilation = dilation, algo = UInt32 (algo))
112
+ ∇conv_filter! (zeros (Float32, size (w)), dy, x, w; pad = pad_ , stride = stride_ , dilation = dilation, algo = UInt32 (algo))
113
113
end
114
114
115
115
∇conv_filter! (dw:: AbstractArray{Float64, 4} , dy:: AbstractArray{Float64, 4} , x:: AbstractArray{Float64, 4} , w:: AbstractArray{Float64, 4} ; pad = 0 , stride = 1 , dilation = 1 , algo = UInt32 (0 ), flipkernel = 0 ) =
0 commit comments