From a6fb26702a51984a7d617b8379c8067d8e647ca5 Mon Sep 17 00:00:00 2001 From: karthik katipally Date: Sat, 30 Nov 2019 08:09:27 +0530 Subject: [PATCH 1/2] Added groupwiseconv and modified depthwise conv for common interface --- src/Flux.jl | 2 +- src/layers/conv.jl | 112 ++++++++++++++++++++++++++++++++++++++++----- 2 files changed, 102 insertions(+), 12 deletions(-) diff --git a/src/Flux.jl b/src/Flux.jl index 9969b32346..8467174264 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -10,7 +10,7 @@ using Zygote: Params, @adjoint, gradient, pullback, @nograd export gradient export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, CrossCor, ConvTranspose, MaxPool, MeanPool, - DepthwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm, + DepthwiseConv, GroupwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm, SkipConnection, params, fmap, cpu, gpu, f32, f64 include("optimise/Optimise.jl") diff --git a/src/layers/conv.jl b/src/layers/conv.jl index f4de3ffcf2..ea32c08438 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -1,4 +1,4 @@ -using NNlib: conv, ∇conv_data, depthwiseconv +using NNlib: conv, ∇conv_data expand(N, i::Tuple) = i expand(N, i::Integer) = ntuple(_ -> i, N) @@ -51,6 +51,7 @@ function (c::Conv)(x::AbstractArray) # TODO: breaks gpu broadcast :( # ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1))) σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1) + @show b cdims = DenseConvDims(x, c.weight; stride=c.stride, padding=c.pad, dilation=c.dilation) σ.(conv(x, c.weight, cdims) .+ b) end @@ -160,41 +161,52 @@ struct DepthwiseConv{N,M,F,A,V} stride::NTuple{N,Int} pad::NTuple{M,Int} dilation::NTuple{N,Int} + groupcount::Int end +# TODO groupcount should be inferred. 
function DepthwiseConv(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity; - stride = 1, pad = 0, dilation = 1) where {T,N} + stride = 1, pad = 0, dilation = 1, groupcount = 1) where {T,N} stride = expand(Val(N-2), stride) pad = expand(Val(2*(N-2)), pad) dilation = expand(Val(N-2), dilation) - return DepthwiseConv(σ, w, b, stride, pad, dilation) + return DepthwiseConv(σ, w, b, stride, pad, dilation, groupcount) end function DepthwiseConv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; - init = glorot_uniform, stride = 1, pad = 0, dilation = 1) where N - @assert ch[2] % ch[1] == 0 "Output channels must be integer multiple of input channels" + init = glorot_uniform, stride = 1, pad = 0, dilation = 1, groupcount=1) where N + @assert ch[2] % groupcount == 0 "Output channels must be an integer multiple of groupcount" + @assert ch[1] % groupcount == 0 "Input channels must be an integer multiple of groupcount" return DepthwiseConv( - init(k..., div(ch[2], ch[1]), ch[1]), + init(k..., div(ch[1], groupcount), ch[2]), zeros(ch[2]), σ; stride = stride, pad = pad, - dilation = dilation + dilation = dilation, + groupcount = groupcount ) end @functor DepthwiseConv +# TODO may not necessary +function depthwiseconv(x, w, ddims::DenseConvDims) + ddims = DenseConvDims(ddims) + return conv(x, w, ddims) +end + function (c::DepthwiseConv)(x) σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1) - cdims = DepthwiseConvDims(x, c.weight; stride=c.stride, padding=c.pad, dilation=c.dilation) - σ.(depthwiseconv(x, c.weight, cdims) .+ b) + cdims = DenseConvDims(x, c.weight; stride=c.stride, padding=c.pad, dilation=c.dilation, groupcount=c.groupcount) + σ.(conv(x, c.weight, cdims) .+ b) end function Base.show(io::IO, l::DepthwiseConv) - print(io, "DepthwiseConv(", size(l.weight)[1:end-2]) - print(io, ", ", size(l.weight)[end], "=>", prod(size(l.weight)[end-1:end])) + print(io, "DepthwiseConv(", size(l.weight, ndims(l.weight)-2)) + print(io, ", ", size(l.weight,
ndims(l.weight)-1)*l.groupcount, "=>", size(l.weight, ndims(l.weight))) l.σ == identity || print(io, ", ", l.σ) + l.groupcount == 1 || print(io, ", groupcount = ", l.groupcount) print(io, ")") end @@ -204,6 +216,84 @@ end (a::DepthwiseConv{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} = a(T.(x)) + +""" + GroupwiseConv(size, in=>out) + GroupwiseConv(size, in=>out, relu) + +Groupwise convolutional layer. `size` should be a tuple like `(2, 2)`. +`in` and `out` specify the number of input and output channels respectively. +Note that both `in` and `out` must be integer multiples of `groupcount`. + +Data should be stored in WHCN order. In other words, a 100×100 RGB image would +be a `100×100×3` array, and a batch of 50 would be a `100×100×3×50` array. + +Takes the keyword arguments `pad`, `stride`, `dilation` and `groupcount`. +""" +struct GroupwiseConv{N,M,F,A,V} + σ::F + weight::A + bias::V + stride::NTuple{N,Int} + pad::NTuple{M,Int} + dilation::NTuple{N,Int} + groupcount::Int +end + +# TODO groupcount should be mandatory +function GroupwiseConv(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity; + stride = 1, pad = 0, dilation = 1, groupcount = 1) where {T,N} + stride = expand(Val(N-2), stride) + pad = expand(Val(2*(N-2)), pad) + dilation = expand(Val(N-2), dilation) + return GroupwiseConv(σ, w, b, stride, pad, dilation, groupcount) +end + +function GroupwiseConv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; + init = glorot_uniform, stride = 1, pad = 0, dilation = 1, groupcount=1) where N + @assert ch[2] % groupcount == 0 "Output channels must be an integer multiple of groupcount" + @assert ch[1] % groupcount == 0 "Input channels must be an integer multiple of groupcount" + return GroupwiseConv( + init(k..., div(ch[1], groupcount), ch[2]), + zeros(ch[2]), + σ; + stride = stride, + pad = pad, + dilation = dilation, + groupcount = groupcount + ) +end + +@functor GroupwiseConv + +# TODO may not necessary +function
groupwiseconv(x, w, ddims::DenseConvDims) + ddims = DenseConvDims(ddims) + return conv(x, w, ddims) +end + +function (c::GroupwiseConv)(x) + σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1) + @info b, c.bias + cdims = DenseConvDims(x, c.weight; stride=c.stride, padding=c.pad, dilation=c.dilation, groupcount=c.groupcount) + σ.(conv(x, c.weight, cdims) .+ b) +end + +function Base.show(io::IO, l::GroupwiseConv) + print(io, "GroupwiseConv(", size(l.weight, ndims(l.weight)-2)) + print(io, ", ", size(l.weight, ndims(l.weight)-1)*l.groupcount, "=>", size(l.weight, ndims(l.weight))) + l.σ == identity || print(io, ", ", l.σ) + l.groupcount == 1 || print(io, ", groupcount = ", l.groupcount) + print(io, ")") +end + +(a::GroupwiseConv{<:Any,<:Any,W})(x::AbstractArray{T}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} = + invoke(a, Tuple{AbstractArray}, x) + +(a::GroupwiseConv{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} = + a(T.(x)) + + """ CrossCor(size, in=>out) CrossCor(size, in=>out, relu) From 32fadfa068eeb474a1e846bb198acc13e3d55bc3 Mon Sep 17 00:00:00 2001 From: karthik katipally Date: Sun, 8 Dec 2019 12:30:10 +0530 Subject: [PATCH 2/2] using only ConvDims for all convolutions. 
--- src/layers/conv.jl | 31 ++++++++++++++++--------------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/src/layers/conv.jl b/src/layers/conv.jl index ea32c08438..40ba7a7859 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -1,4 +1,4 @@ -using NNlib: conv, ∇conv_data +using NNlib: conv, ∇conv_data, group_count, channels_in expand(N, i::Tuple) = i expand(N, i::Integer) = ntuple(_ -> i, N) @@ -51,8 +51,8 @@ function (c::Conv)(x::AbstractArray) # TODO: breaks gpu broadcast :( # ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1))) σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1) - @show b - cdims = DenseConvDims(x, c.weight; stride=c.stride, padding=c.pad, dilation=c.dilation) + cdims = ConvDims(x, c.weight; stride=c.stride, padding=c.pad, dilation=c.dilation, groupcount=1) + # @assert group_count(cdims) == 1 DimensionMismatch("Group count is expected to be 1; (1) vs. $(group_count(cdims)))") σ.(conv(x, c.weight, cdims) .+ b) end @@ -111,8 +111,8 @@ function conv_transpose_dims(c::ConvTranspose, x::AbstractArray) I = (size(x)[1:end-2] .- 1).*c.stride .+ 1 .+ (size(c.weight)[1:end-2] .- 1).*c.dilation .- combined_pad C_in = size(c.weight)[end-1] batch_size = size(x)[end] - # Create DenseConvDims() that looks like the corresponding conv() - return DenseConvDims((I..., C_in, batch_size), size(c.weight); + # Create ConvDims() that looks like the corresponding conv() + return ConvDims((I..., C_in, batch_size), size(c.weight); stride=c.stride, padding=c.pad, dilation=c.dilation, @@ -191,14 +191,16 @@ end @functor DepthwiseConv # TODO may not necessary -function depthwiseconv(x, w, ddims::DenseConvDims) - ddims = DenseConvDims(ddims) +function depthwiseconv(x, w, ddims::ConvDims) + # @assert x[end-1] == channels_in(cdims) DimensionMismatch("Data input channel count ($(x[M-1]) vs. 
$(channels_in(cdims)))") + ddims = ConvDims(ddims) return conv(x, w, ddims) end function (c::DepthwiseConv)(x) σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1) - cdims = DenseConvDims(x, c.weight; stride=c.stride, padding=c.pad, dilation=c.dilation, groupcount=c.groupcount) + cdims = ConvDims(x, c.weight; stride=c.stride, padding=c.pad, dilation=c.dilation, groupcount=c.groupcount) + @assert group_count(cdims) == channels_in(cdims) DimensionMismatch("Data input channel count ≠ group count ($(group_count(cdims)) ≠ $(channels_in(cdims)))") σ.(conv(x, c.weight, cdims) .+ b) end @@ -267,15 +269,14 @@ end @functor GroupwiseConv # TODO may not necessary -function groupwiseconv(x, w, ddims::DenseConvDims) - ddims = DenseConvDims(ddims) +function groupwiseconv(x, w, ddims::ConvDims) + ddims = ConvDims(ddims) return conv(x, w, ddims) end function (c::GroupwiseConv)(x) σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1) - @info b, c.bias - cdims = DenseConvDims(x, c.weight; stride=c.stride, padding=c.pad, dilation=c.dilation, groupcount=c.groupcount) + cdims = ConvDims(x, c.weight; stride=c.stride, padding=c.pad, dilation=c.dilation, groupcount=c.groupcount) σ.(conv(x, c.weight, cdims) .+ b) end @@ -339,8 +340,8 @@ CrossCor(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; @functor CrossCor -function crosscor(x, w, ddims::DenseConvDims) - ddims = DenseConvDims(ddims, F=true) +function crosscor(x, w, ddims::ConvDims) + ddims = ConvDims(ddims, F=true) return conv(x, w, ddims) end @@ -348,7 +349,7 @@ function (c::CrossCor)(x::AbstractArray) # TODO: breaks gpu broadcast :( # ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1))) σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1) - cdims = DenseConvDims(x, c.weight; stride=c.stride, padding=c.pad, dilation=c.dilation) + cdims = ConvDims(x, c.weight; stride=c.stride, padding=c.pad, dilation=c.dilation) σ.(crosscor(x, c.weight, cdims) .+ b) end