@@ -6,9 +6,9 @@ _paddims(x::Tuple, y::Tuple) = (x..., y[(end - (length(y) - length(x) - 1)):end]
6
6
expand (N, i:: Tuple ) = i
7
7
expand (N, i:: Integer ) = ntuple (_ -> i, N)
8
8
9
- conv_reshape_bias (c) = c. bias isa AbstractVector ?
10
- reshape (c . bias, map (_ -> 1 , c . stride) ... , :, 1 ) :
11
- c . bias
9
+ conv_reshape_bias (c) = conv_reshape_bias ( c. bias, c . stride)
10
+ conv_reshape_bias ( @nospecialize ( bias), _) = bias
11
+ conv_reshape_bias (bias :: AbstractVector , stride) = reshape ( bias, map (_ -> 1 , stride) ... , :, 1 )
12
12
13
13
"""
14
14
SamePad()
164
164
165
165
@functor Conv
166
166
167
+ conv_dims (c:: Conv , x:: AbstractArray ) =
168
+ DenseConvDims (x, c. weight; stride = c. stride, padding = c. pad, dilation = c. dilation, groups = c. groups)
169
+
170
+ ChainRulesCore. @non_differentiable conv_dims (:: Any , :: Any )
171
+
167
172
function (c:: Conv )(x:: AbstractArray )
168
173
σ = NNlib. fast_act (c. σ, x)
169
- cdims = DenseConvDims (x, c . weight; stride = c . stride, padding = c . pad, dilation = c . dilation, groups = c . groups )
174
+ cdims = conv_dims (c, x )
170
175
σ .(conv (x, c. weight, cdims) .+ conv_reshape_bias (c))
171
176
end
172
177
@@ -400,9 +405,14 @@ function crosscor(x, w, ddims::DenseConvDims)
400
405
return conv (x, w, ddims)
401
406
end
402
407
408
+ crosscor_dims (c:: CrossCor , x:: AbstractArray ) =
409
+ DenseConvDims (x, c. weight; stride = c. stride, padding = c. pad, dilation = c. dilation)
410
+
411
+ ChainRulesCore. @non_differentiable crosscor_dims (:: Any , :: Any )
412
+
403
413
function (c:: CrossCor )(x:: AbstractArray )
404
414
σ = NNlib. fast_act (c. σ, x)
405
- cdims = DenseConvDims (x, c . weight; stride = c . stride, padding = c . pad, dilation = c . dilation )
415
+ cdims = crosscor_dims (c, x )
406
416
σ .(crosscor (x, c. weight, cdims) .+ conv_reshape_bias (c))
407
417
end
408
418
0 commit comments