Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/Flux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ using Zygote: Params, @adjoint, gradient, pullback, @nograd
export gradient

export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, CrossCor, ConvTranspose, MaxPool, MeanPool,
DepthwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm,
DepthwiseConv, GroupwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm,
SkipConnection, params, fmap, cpu, gpu, f32, f64

include("optimise/Optimise.jl")
Expand Down
125 changes: 108 additions & 17 deletions src/layers/conv.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using NNlib: conv, ∇conv_data, depthwiseconv
using NNlib: conv, ∇conv_data, group_count, channels_in

expand(N, i::Tuple) = i
expand(N, i::Integer) = ntuple(_ -> i, N)
Expand Down Expand Up @@ -51,7 +51,8 @@ function (c::Conv)(x::AbstractArray)
# TODO: breaks gpu broadcast :(
# ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1)))
σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
cdims = DenseConvDims(x, c.weight; stride=c.stride, padding=c.pad, dilation=c.dilation)
cdims = ConvDims(x, c.weight; stride=c.stride, padding=c.pad, dilation=c.dilation, groupcount=1)
# @assert group_count(cdims) == 1 DimensionMismatch("Group count is expected to be 1; (1) vs. $(group_count(cdims)))")
σ.(conv(x, c.weight, cdims) .+ b)
end

Expand Down Expand Up @@ -110,8 +111,8 @@ function conv_transpose_dims(c::ConvTranspose, x::AbstractArray)
I = (size(x)[1:end-2] .- 1).*c.stride .+ 1 .+ (size(c.weight)[1:end-2] .- 1).*c.dilation .- combined_pad
C_in = size(c.weight)[end-1]
batch_size = size(x)[end]
# Create DenseConvDims() that looks like the corresponding conv()
return DenseConvDims((I..., C_in, batch_size), size(c.weight);
# Create ConvDims() that looks like the corresponding conv()
return ConvDims((I..., C_in, batch_size), size(c.weight);
stride=c.stride,
padding=c.pad,
dilation=c.dilation,
Expand Down Expand Up @@ -160,41 +161,54 @@ struct DepthwiseConv{N,M,F,A,V}
stride::NTuple{N,Int}
pad::NTuple{M,Int}
dilation::NTuple{N,Int}
groupcount::Int
end

# TODO groupcount should be inferred.
function DepthwiseConv(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
stride = 1, pad = 0, dilation = 1) where {T,N}
stride = 1, pad = 0, dilation = 1, groupcount = 1) where {T,N}
stride = expand(Val(N-2), stride)
pad = expand(Val(2*(N-2)), pad)
dilation = expand(Val(N-2), dilation)
return DepthwiseConv(σ, w, b, stride, pad, dilation)
return DepthwiseConv(σ, w, b, stride, pad, dilation, groupcount)
end

function DepthwiseConv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
init = glorot_uniform, stride = 1, pad = 0, dilation = 1) where N
@assert ch[2] % ch[1] == 0 "Output channels must be integer multiple of input channels"
init = glorot_uniform, stride = 1, pad = 0, dilation = 1, groupcount=1) where N
@assert ch[2] % groupcount == 0 "Output channels must be integer multiple of input channels"
@assert ch[1] % groupcount == 0 "Input channels must be interger multiples of groupcount"
return DepthwiseConv(
init(k..., div(ch[2], ch[1]), ch[1]),
init(k..., div(ch[1], groupcount), ch[2]),
zeros(ch[2]),
σ;
stride = stride,
pad = pad,
dilation = dilation
dilation = dilation,
groupcount = groupcount
)
end

@functor DepthwiseConv

# TODO may not necessary
function depthwiseconv(x, w, ddims::ConvDims)
# @assert x[end-1] == channels_in(cdims) DimensionMismatch("Data input channel count ($(x[M-1]) vs. $(channels_in(cdims)))")
ddims = ConvDims(ddims)
return conv(x, w, ddims)
end

function (c::DepthwiseConv)(x)
σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
cdims = DepthwiseConvDims(x, c.weight; stride=c.stride, padding=c.pad, dilation=c.dilation)
σ.(depthwiseconv(x, c.weight, cdims) .+ b)
cdims = ConvDims(x, c.weight; stride=c.stride, padding=c.pad, dilation=c.dilation, groupcount=c.groupcount)
@assert group_count(cdims) == channels_in(cdims) DimensionMismatch("Data input channel count ≠ group count ($(group_count(cdims)) ≠ $(channels_in(cdims)))")
σ.(conv(x, c.weight, cdims) .+ b)
end

function Base.show(io::IO, l::DepthwiseConv)
print(io, "DepthwiseConv(", size(l.weight)[1:end-2])
print(io, ", ", size(l.weight)[end], "=>", prod(size(l.weight)[end-1:end]))
print(io, "DepthwiseConv(", size(l.weight, ndims(l.weight)-2))
print(io, ", ", size(l.weight, ndims(l.weight)-1)*l.groupcount, "=>", size(l.weight, ndims(l.weight)))
l.σ == identity || print(io, ", ", l.σ)
l.groupcount == 1 || print(io, ", groupcount = ", l.groupcount)
print(io, ")")
end

Expand All @@ -204,6 +218,83 @@ end
(a::DepthwiseConv{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
a(T.(x))


"""
GroupwiseConv(size, in=>out)
GroupwiseConv(size, in=>out, relu)

Groupwise convolutional layer. `size` should be a tuple like `(2, 2)`.
`in` and `out` specify the number of input and output channels respectively.
Note that `out` must be an integer multiple of `in`.

Data should be stored in WHCN order. In other words, a 100×100 RGB image would
be a `100×100×3` array, and a batch of 50 would be a `100×100×3×50` array.

Takes the keyword arguments `pad`, `stride` and `dilation`.
"""
struct GroupwiseConv{N,M,F,A,V}
σ::F
weight::A
bias::V
stride::NTuple{N,Int}
pad::NTuple{M,Int}
dilation::NTuple{N,Int}
groupcount::Int
end

# TODO groupcount should be mandatory
function GroupwiseConv(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
stride = 1, pad = 0, dilation = 1, groupcount = 1) where {T,N}
stride = expand(Val(N-2), stride)
pad = expand(Val(2*(N-2)), pad)
dilation = expand(Val(N-2), dilation)
return GroupwiseConv(σ, w, b, stride, pad, dilation, groupcount)
end

function GroupwiseConv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
init = glorot_uniform, stride = 1, pad = 0, dilation = 1, groupcount=1) where N
@assert ch[2] % groupcount == 0 "Output channels must be integer multiple of input channels"
@assert ch[1] % groupcount == 0 "Input channels must be interger multiples of groupcount"
return GroupwiseConv(
init(k..., div(ch[1], groupcount), ch[2]),
zeros(ch[2]),
σ;
stride = stride,
pad = pad,
dilation = dilation,
groupcount = groupcount
)
end

@functor GroupwiseConv

# TODO may not necessary
function groupwiseconv(x, w, ddims::ConvDims)
ddims = ConvDims(ddims)
return conv(x, w, ddims)
end

function (c::GroupwiseConv)(x)
σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
cdims = ConvDims(x, c.weight; stride=c.stride, padding=c.pad, dilation=c.dilation, groupcount=c.groupcount)
σ.(conv(x, c.weight, cdims) .+ b)
end

function Base.show(io::IO, l::GroupwiseConv)
print(io, "GroupwiseConv(", size(l.weight, ndims(l.weight)-2))
print(io, ", ", size(l.weight, ndims(l.weight)-1)*l.groupcount, "=>", size(l.weight, ndims(l.weight)))
l.σ == identity || print(io, ", ", l.σ)
l.groupcount == 1 || print(io, ", groupcount = ", l.groupcount)
print(io, ")")
end

(a::GroupwiseConv{<:Any,<:Any,W})(x::AbstractArray{T}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
invoke(a, Tuple{AbstractArray}, x)

(a::GroupwiseConv{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
a(T.(x))


"""
CrossCor(size, in=>out)
CrossCor(size, in=>out, relu)
Expand Down Expand Up @@ -249,16 +340,16 @@ CrossCor(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;

@functor CrossCor

function crosscor(x, w, ddims::DenseConvDims)
ddims = DenseConvDims(ddims, F=true)
function crosscor(x, w, ddims::ConvDims)
ddims = ConvDims(ddims, F=true)
return conv(x, w, ddims)
end

function (c::CrossCor)(x::AbstractArray)
# TODO: breaks gpu broadcast :(
# ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1)))
σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
cdims = DenseConvDims(x, c.weight; stride=c.stride, padding=c.pad, dilation=c.dilation)
cdims = ConvDims(x, c.weight; stride=c.stride, padding=c.pad, dilation=c.dilation)
σ.(crosscor(x, c.weight, cdims) .+ b)
end

Expand Down