Skip to content

Commit dd1cb04

Browse files
committed
Flip name to transpose_switchbatch()
Also specialize on kernel size, as that turns out to be helpful for performance.
1 parent 1b4192c commit dd1cb04

File tree

4 files changed

+19
-27
lines changed

4 files changed

+19
-27
lines changed

src/dim_helpers.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,16 @@ include("dim_helpers/PoolDims.jl")
66

77

88
"""
9-
transpose_flipbatch(x::AbstractArray)
9+
transpose_swapbatch(x::AbstractArray)
1010
11-
Given an AbstractArray, flip its batch and channel axes, as we must during transposed
11+
Given an AbstractArray, swap its batch and channel axes, as we must during transposed
1212
convolution. We do this to the operands during convolution, and then again to the
1313
output once we're done.
1414
"""
15-
function transpose_flipbatch(x::AbstractArray)
15+
function transpose_swapbatch(x::AbstractArray)
1616
return permutedims(x, ((1:(ndims(x)-2))..., ndims(x), ndims(x)-1))
1717
end
18-
function transpose_flipbatch(x::Tuple)
18+
function transpose_swapbatch(x::Tuple)
1919
return (x[1:end-2]..., x[end], x[end-1])
2020
end
2121

src/dim_helpers/DenseConvDims.jl

Lines changed: 9 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,15 @@ export DenseConvDims
55
66
Concrete subclass of `ConvDims` for a normal, dense, conv2d/conv3d.
77
"""
8-
struct DenseConvDims{N,S,P,D,F} <: ConvDims{N,S,P,D,F}
8+
struct DenseConvDims{N,K,C_in,C_out,S,P,D,F} <: ConvDims{N,S,P,D,F}
99
I::NTuple{N,Int}
10-
K::NTuple{N,Int}
11-
C_in::Int
12-
C_out::Int
1310
end
1411

1512
# Getters for the fields
1613
input_size(c::DenseConvDims) = c.I
17-
kernel_size(c::DenseConvDims) = c.K
18-
channels_in(c::DenseConvDims) = c.C_in
19-
channels_out(c::DenseConvDims) = c.C_out
14+
kernel_size(c::DenseConvDims{N,K,C_in,C_out,S,P,D,F}) where {N,K,C_in,C_out,S,P,D,F} = K
15+
channels_in(c::DenseConvDims{N,K,C_in,C_out,S,P,D,F}) where {N,K,C_in,C_out,S,P,D,F} = C_in
16+
channels_out(c::DenseConvDims{N,K,C_in,C_out,S,P,D,F}) where {N,K,C_in,C_out,S,P,D,F} = C_out
2017

2118
# Convenience wrapper to create DenseConvDims objects
2219
function DenseConvDims(x_size::NTuple{M}, w_size::NTuple{M};
@@ -34,22 +31,16 @@ function DenseConvDims(x_size::NTuple{M}, w_size::NTuple{M};
3431
# The type parameters are what
3532
return DenseConvDims{
3633
M - 2,
34+
w_size[1:end-2],
35+
x_size[end-1],
36+
w_size[end],
3737
stride,
3838
padding,
3939
dilation,
4040
flipkernel
4141
}(
42-
# Image spatial size
42+
# Input spatial size
4343
x_size[1:end-2],
44-
45-
# Kernel spatial size
46-
w_size[1:end-2],
47-
48-
# Input channels
49-
x_size[end-1],
50-
51-
# Output channels
52-
w_size[end],
5344
)
5445
end
5546

@@ -66,7 +57,7 @@ end
6657
function DenseConvDims(c::ConvDims; N=spatial_dims(c), I=input_size(c), K=kernel_size(c),
6758
C_in=channels_in(c), C_out=channels_out(c), S=stride(c),
6859
P=padding(c), D=dilation(c), F=flipkernel(c))
69-
return DenseConvDims{N, S, P, D, F}(I, K, C_in, C_out)
60+
return DenseConvDims{N, K, C_in, C_out, S, P, D, F}(I)
7061
end
7162

7263
function check_dims(x::NTuple{M}, w::NTuple{M}, y::NTuple{M}, cdims::DenseConvDims) where {M}

src/dim_helpers/PoolDims.jl

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,7 @@ struct PoolDims{N,K,S,P,D} <: ConvDims{N, S, P, D, false}
1313
end
1414

1515
# Getters for both type parameters and fields
16-
spatial_dims(c::PoolDims{N,K,S,P,D}) where {N, K, S, P, D} = N
1716
kernel_size(c::PoolDims{N,K,S,P,D}) where {N, K, S, P, D} = K
18-
stride(c::PoolDims{N,K,S,P,D}) where {N, K, S, P, D} = S
19-
padding(c::PoolDims{N,K,S,P,D}) where {N, K, S, P, D} = P
20-
dilation(c::PoolDims{N,K,S,P,D}) where {N, K, S, P, D} = D
2117
input_size(c::PoolDims) = c.I
2218
channels_in(c::PoolDims) = c.C_in
2319
channels_out(c::PoolDims) = c.C_in

src/impl/conv_im2col.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,12 @@ out along the rows of `col`, one for each output pixel. This routine is used by
163163
im2col-based convolutions, just with extra singleton dimensions added in the case of `2d`
164164
or `1d` images.
165165
"""
166-
function im2col!(col::AbstractArray{T,2}, x::AbstractArray{T,4}, cdims::ConvDims) where T
166+
function im2col!(col::AbstractArray{T,2}, x::AbstractArray{T,4},
167+
cdims::ConvDims) where {T}
168+
if spatial_dims(cdims) != 3
169+
throw(DimensionMismatch("im2col!() only accepts 3d convoluitional inputs"))
170+
end
171+
167172
# Extract those nice, compile-time constant type parameters from `cdims`.
168173
width, height, depth = input_size(cdims)
169174
kernel_w, kernel_h, kernel_d = kernel_size(cdims)

0 commit comments

Comments
 (0)