Skip to content

Commit 946b815

Browse files
authored
Merge pull request #1956 from FluxML/bc/conv-helper-tweaks
Type stable `conv_reshape_bias` and AD-friendly `ConvDims` helpers
2 parents 58785f3 + 66c2ec6 commit 946b815

File tree

1 file changed

+15
-5
lines changed

1 file changed

+15
-5
lines changed

src/layers/conv.jl

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@ _paddims(x::Tuple, y::Tuple) = (x..., y[(end - (length(y) - length(x) - 1)):end]
66
expand(N, i::Tuple) = i
77
expand(N, i::Integer) = ntuple(_ -> i, N)
88

9-
conv_reshape_bias(c) = c.bias isa AbstractVector ?
10-
reshape(c.bias, map(_->1, c.stride)..., :, 1) :
11-
c.bias
9+
conv_reshape_bias(c) = conv_reshape_bias(c.bias, c.stride)
10+
conv_reshape_bias(@nospecialize(bias), _) = bias
11+
conv_reshape_bias(bias::AbstractVector, stride) = reshape(bias, map(_->1, stride)..., :, 1)
1212

1313
"""
1414
SamePad()
@@ -164,9 +164,14 @@ end
164164

165165
@functor Conv
166166

167+
conv_dims(c::Conv, x::AbstractArray) =
168+
DenseConvDims(x, c.weight; stride = c.stride, padding = c.pad, dilation = c.dilation, groups = c.groups)
169+
170+
ChainRulesCore.@non_differentiable conv_dims(::Any, ::Any)
171+
167172
function (c::Conv)(x::AbstractArray)
168173
σ = NNlib.fast_act(c.σ, x)
169-
cdims = DenseConvDims(x, c.weight; stride = c.stride, padding = c.pad, dilation = c.dilation, groups = c.groups)
174+
cdims = conv_dims(c, x)
170175
σ.(conv(x, c.weight, cdims) .+ conv_reshape_bias(c))
171176
end
172177

@@ -400,9 +405,14 @@ function crosscor(x, w, ddims::DenseConvDims)
400405
return conv(x, w, ddims)
401406
end
402407

408+
crosscor_dims(c::CrossCor, x::AbstractArray) =
409+
DenseConvDims(x, c.weight; stride = c.stride, padding = c.pad, dilation = c.dilation)
410+
411+
ChainRulesCore.@non_differentiable crosscor_dims(::Any, ::Any)
412+
403413
function (c::CrossCor)(x::AbstractArray)
404414
σ = NNlib.fast_act(c.σ, x)
405-
cdims = DenseConvDims(x, c.weight; stride=c.stride, padding=c.pad, dilation=c.dilation)
415+
cdims = crosscor_dims(c, x)
406416
σ.(crosscor(x, c.weight, cdims) .+ conv_reshape_bias(c))
407417
end
408418

0 commit comments

Comments
 (0)