Merged (changes from 21 commits)
1 change: 1 addition & 0 deletions src/Flux.jl
@@ -43,6 +43,7 @@ include("layers/conv.jl")
include("layers/recurrent.jl")
include("layers/normalise.jl")
include("layers/upsample.jl")
include("layers/show.jl")

include("outputsize.jl")

3 changes: 1 addition & 2 deletions src/functor.jl
@@ -1,8 +1,7 @@
import Adapt: adapt, adapt_storage
using LinearAlgebra: Cholesky
using Zygote: IdSet
import Functors: @functor, functor, fmap
import Functors
import Functors: Functors, @functor, functor, fmap, isleaf

trainable(m) = functor(m)[1]
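(Not part of the diff: the newly imported `Functors.isleaf` is what the new `show.jl` uses, via `_show_leaflike`, to decide when to stop recursing into a model. A rough sketch of its behaviour, assuming standard Functors semantics:)

```julia
julia> Functors.isleaf(rand(3))      # plain arrays have no children, so they are leaves
true

julia> Functors.isleaf(Dense(2, 3))  # @functor'd layers have children, so they are not
false
```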

4 changes: 3 additions & 1 deletion src/layers/basic.jl
@@ -417,8 +417,10 @@ Parallel(connection, layers...) = Parallel(connection, layers)
Base.getindex(m::Parallel, i::Integer) = m.layers[i]
Base.getindex(m::Parallel, i::AbstractVector) = Parallel(m.connection, m.layers[i]...)

trainable(m::Parallel) = (m.connection, m.layers...)

function Base.show(io::IO, m::Parallel)
print(io, "Parallel(", m.connection, ", ")
join(io, m.layers, ", ")
print(io, ")")
end
end
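(Illustrative note, not part of the diff: the new `trainable(m::Parallel)` above puts the connection first, so it also shows up as a child in the nested printing added by `show.jl`. A quick sketch, using `Flux.trainable` since it is not exported:)

```julia
julia> m = Parallel(vcat, Dense(2, 3), Dense(2, 3));

julia> Flux.trainable(m)   # connection first, then the wrapped layers
(vcat, Dense(2, 3), Dense(2, 3))
```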
71 changes: 41 additions & 30 deletions src/layers/conv.jl
@@ -67,7 +67,7 @@ See also [`ConvTranspose`](@ref), [`DepthwiseConv`](@ref), [`CrossCor`](@ref).
julia> xs = rand(Float32, 100, 100, 3, 50); # a batch of images

julia> lay = Conv((5,5), 3 => 7, relu; bias=false)
Conv((5, 5), 3=>7, relu)
Conv((5, 5), 3 => 7, relu) # 525 parameters

julia> lay(xs) |> size
(96, 96, 7, 50)
@@ -98,7 +98,7 @@ end
Conv(weight::AbstractArray, [bias, activation; stride, pad, dilation])

Constructs a convolutional layer with the given weight and bias.
Accepts the same keywords (and has the same defaults) as the `Conv((4,4), 3=>7, relu)`
Accepts the same keywords (and has the same defaults) as the `Conv((4,4), 3 => 7, relu)`
method.

# Examples
@@ -108,7 +108,7 @@ julia> weight = rand(3, 4, 5);
julia> bias = zeros(5);

julia> c1 = Conv(weight, bias, sigmoid) # expects 1 spatial dimension
Conv((3,), 4=>5, σ)
Conv((3,), 4 => 5, σ) # 65 parameters

julia> c1(randn(100, 4, 64)) |> size
(98, 5, 64)
@@ -134,7 +134,7 @@ function Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity
end

"""
convfilter(filter::Tuple, in=>out)
convfilter(filter::Tuple, in => out)

Constructs a standard convolutional weight matrix with given `filter` and
channels from `in` to `out`.
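(For illustration only, not part of the diff, assuming the usual `(filter..., in, out)` weight layout described in the `Conv` docs:)

```julia
julia> convfilter((3, 3), 3 => 7) |> size
(3, 3, 3, 7)
```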
@@ -159,11 +159,18 @@ end

function Base.show(io::IO, l::Conv)
print(io, "Conv(", size(l.weight)[1:ndims(l.weight)-2])
print(io, ", ", size(l.weight, ndims(l.weight)-1), "=>", size(l.weight, ndims(l.weight)))
l.σ == identity || print(io, ", ", l.σ)
print(io, ", ", size(l.weight, ndims(l.weight)-1), " => ", size(l.weight, ndims(l.weight)))
_print_conv_opt(io, l)
print(io, ")")
end

function _print_conv_opt(io::IO, l)
l.σ == identity || print(io, ", ", l.σ)
all(==(0), l.pad) || print(io, ", pad=", _maybetuple_string(l.pad))
all(==(1), l.stride) || print(io, ", stride=", _maybetuple_string(l.stride))
all(==(1), l.dilation) || print(io, ", dilation=", _maybetuple_string(l.dilation))
l.bias == Zeros() && print(io, ", bias=false")
end

"""
ConvTranspose(filter, in => out, σ=identity; stride=1, pad=0, dilation=1, [bias, init])
@@ -184,15 +191,15 @@ See also [`Conv`](@ref) for more detailed description of keywords.
julia> xs = rand(Float32, 100, 100, 3, 50); # a batch of 50 RGB images

julia> lay = ConvTranspose((5,5), 3 => 7, relu)
ConvTranspose((5, 5), 3=>7, relu)
ConvTranspose((5, 5), 3 => 7, relu) # 532 parameters

julia> lay(xs) |> size
(104, 104, 7, 50)

julia> ConvTranspose((5,5), 3=>7, stride=2)(xs) |> size
julia> ConvTranspose((5,5), 3 => 7, stride=2)(xs) |> size
(203, 203, 7, 50)

julia> ConvTranspose((5,5), 3=>7, stride=3, pad=SamePad())(xs) |> size
julia> ConvTranspose((5,5), 3 => 7, stride=3, pad=SamePad())(xs) |> size
(300, 300, 7, 50)
```
"""
@@ -209,7 +216,7 @@ end
ConvTranspose(weight::AbstractArray, [bias, activation; stride, pad, dilation])

Constructs a layer with the given weight and bias arrays.
Accepts the same keywords as the `ConvTranspose((4,4), 3=>7, relu)` method.
Accepts the same keywords as the `ConvTranspose((4,4), 3 => 7, relu)` method.
"""
function ConvTranspose(w::AbstractArray{T,N}, bias = true, σ = identity;
stride = 1, pad = 0, dilation = 1) where {T,N}
@@ -255,8 +262,8 @@ end

function Base.show(io::IO, l::ConvTranspose)
print(io, "ConvTranspose(", size(l.weight)[1:ndims(l.weight)-2])
print(io, ", ", size(l.weight, ndims(l.weight)), "=>", size(l.weight, ndims(l.weight)-1))
l.σ == identity || print(io, ", ", l.σ)
print(io, ", ", size(l.weight, ndims(l.weight)), " => ", size(l.weight, ndims(l.weight)-1))
_print_conv_opt(io, l)
print(io, ")")
end

@@ -266,7 +273,7 @@ function calc_padding(::Type{ConvTranspose}, pad::SamePad, k::NTuple{N,T}, dilat
end

"""
DepthwiseConv(filter, in=>out, σ=identity; stride=1, pad=0, dilation=1, [bias, init])
DepthwiseConv(filter, in => out, σ=identity; stride=1, pad=0, dilation=1, [bias, init])

Depthwise convolutional layer. `filter` is a tuple of integers
specifying the size of the convolutional kernel, while
@@ -284,7 +291,7 @@ See also [`Conv`](@ref) for more detailed description of keywords.
julia> xs = rand(Float32, 100, 100, 3, 50); # a batch of 50 RGB images

julia> lay = DepthwiseConv((5,5), 3 => 6, relu; bias=false)
DepthwiseConv((5, 5), 3=>6, relu)
DepthwiseConv((5, 5), 3 => 6, relu) # 150 parameters

julia> lay(xs) |> size
(96, 96, 6, 50)
@@ -306,7 +313,7 @@ end
DepthwiseConv(weight::AbstractArray, bias, [activation; stride, pad, dilation])

Constructs a layer with the given weight and bias arrays.
Accepts the same keywords as the `DepthwiseConv((4,4), 3=>6, relu)` method.
Accepts the same keywords as the `DepthwiseConv((4,4), 3 => 6, relu)` method.
"""
function DepthwiseConv(w::AbstractArray{T,N}, bias = true, σ = identity;
stride = 1, pad = 0, dilation = 1) where {T,N}
@@ -327,7 +334,7 @@ end
@functor DepthwiseConv

"""
depthwiseconvfilter(filter::Tuple, in=>out)
depthwiseconvfilter(filter::Tuple, in => out)

Constructs a depthwise convolutional weight array defined by `filter` and channels
from `in` to `out`.
@@ -348,8 +355,8 @@ end

function Base.show(io::IO, l::DepthwiseConv)
print(io, "DepthwiseConv(", size(l.weight)[1:end-2])
print(io, ", ", size(l.weight)[end], "=>", prod(size(l.weight)[end-1:end]))
l.σ == identity || print(io, ", ", l.σ)
print(io, ", ", size(l.weight)[end], " => ", prod(size(l.weight)[end-1:end]))
_print_conv_opt(io, l)
print(io, ")")
end

@@ -372,12 +379,12 @@ See also [`Conv`](@ref) for more detailed description of keywords.
julia> xs = rand(Float32, 100, 100, 3, 50); # a batch of 50 RGB images

julia> lay = CrossCor((5,5), 3 => 6, relu; bias=false)
CrossCor((5, 5), 3=>6, relu)
CrossCor((5, 5), 3 => 6, relu) # 450 parameters

julia> lay(xs) |> size
(96, 96, 6, 50)

julia> CrossCor((5,5), 3=>7, stride=3, pad=(2,0))(xs) |> size
julia> CrossCor((5,5), 3 => 7, stride=3, pad=(2,0))(xs) |> size
(34, 32, 7, 50)
```
"""
@@ -394,7 +401,7 @@ end
CrossCor(weight::AbstractArray, [bias, activation; stride, pad, dilation])

Constructs a layer with the given weight and bias arrays.
Accepts the same keywords as the `CrossCor((4,4), 3=>7, relu)` method.
Accepts the same keywords as the `CrossCor((4,4), 3 => 7, relu)` method.
"""
function CrossCor(w::AbstractArray{T,N}, bias = true, σ = identity;
stride = 1, pad = 0, dilation = 1) where {T,N}
@@ -429,8 +436,8 @@ end

function Base.show(io::IO, l::CrossCor)
print(io, "CrossCor(", size(l.weight)[1:ndims(l.weight)-2])
print(io, ", ", size(l.weight, ndims(l.weight)-1), "=>", size(l.weight, ndims(l.weight)))
l.σ == identity || print(io, ", ", l.σ)
print(io, ", ", size(l.weight, ndims(l.weight)-1), " => ", size(l.weight, ndims(l.weight)))
_print_conv_opt(io, l)
print(io, ")")
end

@@ -529,8 +536,7 @@ See also [`MaxPool`](@ref), [`GlobalMeanPool`](@ref).
```jldoctest
julia> xs = rand(Float32, 100, 100, 3, 50);

julia> m = Chain(Conv((3,3), 3=>7), GlobalMaxPool())
Chain(Conv((3, 3), 3=>7), GlobalMaxPool())
julia> m = Chain(Conv((3,3), 3 => 7), GlobalMaxPool());

julia> m(xs) |> size
(1, 1, 7, 50)
@@ -567,8 +573,7 @@ by performing mean pooling on the complete (w,h)-shaped feature maps.
```jldoctest
julia> xs = rand(Float32, 100, 100, 3, 50);

julia> m = Chain(Conv((3,3), 3=>7), GlobalMeanPool())
Chain(Conv((3, 3), 3=>7), GlobalMeanPool())
julia> m = Chain(Conv((3,3), 3 => 7), GlobalMeanPool());

julia> m(xs) |> size
(1, 1, 7, 50)
@@ -611,8 +616,11 @@ See also [`Conv`](@ref), [`MeanPool`](@ref), [`AdaptiveMaxPool`](@ref), [`Global
```jldoctest
julia> xs = rand(Float32, 100, 100, 3, 50); # batch of 50 RGB images

julia> m = Chain(Conv((5, 5), 3=>7, pad=SamePad()), MaxPool((5, 5), pad=SamePad()))
Chain(Conv((5, 5), 3=>7), MaxPool((5, 5), pad=2))
julia> m = Chain(Conv((5, 5), 3 => 7, pad=SamePad()), MaxPool((5, 5), pad=SamePad()))
Chain(
Conv((5, 5), 3 => 7, pad=2), # 532 parameters
MaxPool((5, 5), pad=2),
)

julia> m[1](xs) |> size
(100, 100, 7, 50)
@@ -674,7 +682,10 @@ See also [`Conv`](@ref), [`MaxPool`](@ref), [`AdaptiveMeanPool`](@ref).
julia> xs = rand(Float32, 100, 100, 3, 50);

julia> m = Chain(Conv((5,5), 3 => 7), MeanPool((5,5), pad=SamePad()))
Chain(Conv((5, 5), 3=>7), MeanPool((5, 5), pad=2))
Chain(
Conv((5, 5), 3 => 7), # 532 parameters
MeanPool((5, 5), pad=2),
)

julia> m[1](xs) |> size
(96, 96, 7, 50)
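(Sketch, not part of the diff, of what the new `_print_conv_opt` path prints for non-default options; the trailing parameter count comes from `_layer_show` in the new `show.jl`, and the exact column alignment may differ:)

```julia
julia> Conv((3,3), 3 => 7, relu; stride=2, pad=1)
Conv((3, 3), 3 => 7, relu, pad=1, stride=2)  # 196 parameters
```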
7 changes: 4 additions & 3 deletions src/layers/normalise.jl
@@ -278,8 +278,8 @@ testmode!(m::BatchNorm, mode=true) =

function Base.show(io::IO, l::BatchNorm)
print(io, "BatchNorm($(l.chs)")
l.λ == identity || print(io, ", $(l.λ)")
hasaffine(l) || print(io, ", affine=false")
(l.λ == identity) || print(io, ", $(l.λ)")
hasaffine(l) || print(io, ", affine=false") # ??
print(io, ")")
end

@@ -443,8 +443,9 @@ testmode!(m::GroupNorm, mode = true) =
(m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m)

function Base.show(io::IO, l::GroupNorm)
# print(io, "GroupNorm($(join(size(l.β), ", "))", ", ", l.G)
print(io, "GroupNorm($(l.chs), $(l.G)")
l.λ == identity || print(io, ", $(l.λ)")
l.λ == identity || print(io, ", ", l.λ)
hasaffine(l) || print(io, ", affine=false")
print(io, ")")
end
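(Sketch, not part of the diff: with the `affine=false` flag the updated method should print roughly the following, with no parameter comment since an affine-free layer has no trainable arrays:)

```julia
julia> BatchNorm(8, relu; affine=false)
BatchNorm(8, relu, affine=false)
```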
110 changes: 110 additions & 0 deletions src/layers/show.jl
@@ -0,0 +1,110 @@

for T in [
:Chain, :Parallel, :SkipConnection, :Recur # container types
]
@eval function Base.show(io::IO, m::MIME"text/plain", x::$T)
if get(io, :typeinfo, nothing) === nothing # e.g. top level in REPL
_big_show(io, x)
elseif !get(io, :compact, false) # e.g. printed inside a Vector, but not a Matrix
_layer_show(io, x)
else
show(io, x)
end
end
end

function _big_show(io::IO, obj, indent::Int=0)
children = trainable(obj)
if all(_show_leaflike, children)
_layer_show(io, obj, indent)
else
println(io, " "^indent, nameof(typeof(obj)), "(")
for c in children
_big_show(io, c, indent+2)
end
if indent == 0
print(io, ")")
_big_finale(io, obj)
else
println(io, " "^indent, "),")
end
end
end

_show_leaflike(x) = isleaf(x) # mostly follow Functors, except for:
_show_leaflike(::Tuple{Vararg{<:Number}}) = true # e.g. stride of Conv
_show_leaflike(::Tuple{Vararg{<:AbstractArray}}) = true # e.g. parameters of LSTMcell
_show_leaflike(::Diagonal) = true # appears inside LayerNorm

for T in [
:Conv, :ConvTranspose, :CrossCor, :DepthwiseConv, :Dense,
:BatchNorm, :LayerNorm, :InstanceNorm, :GroupNorm,
]
@eval function Base.show(io::IO, m::MIME"text/plain", x::$T)
if !get(io, :compact, false)
_layer_show(io, x)
else
show(io, x)
end
end
end

function _layer_show(io::IO, layer, indent::Int=0)
str = sprint(show, layer, context=io)
print(io, " "^indent, str, indent==0 ? "" : ",")
if !isempty(params(layer))
print(io, " "^max(2, (indent==0 ? 20 : 39) - indent - length(str)))
printstyled(io, "# ", underscorise(sum(length, params(layer))), " parameters"; color=:light_black)
nonparam = _childarray_sum(length, layer) - sum(length, params(layer))
if nonparam > 0
printstyled(io, ", plus ", underscorise(nonparam); color=:light_black)
end
_nan_show(io, params(layer))
end
indent==0 || println(io)
end

function _big_finale(io::IO, m)
ps = params(m)
if length(ps) > 2
pars = underscorise(sum(length, ps))
bytes = Base.format_bytes(Base.summarysize(m))
noncnt = _childarray_sum(_->1, m) - length(ps)
if noncnt > 0
nonparam = underscorise(_childarray_sum(length, m) - sum(length, ps))
printstyled(io, " "^09, "# Total: ", length(ps), " trainable arrays, with "; color=:light_black)
println(io, pars, " parameters")
printstyled(io, " "^10, "# plus ", noncnt, " non-trainable, ", nonparam, " parameters, summarysize "; color=:light_black)
print(io, bytes)
else
printstyled(io, " "^19, "# Total: ", length(ps), " arrays, "; color=:light_black)
print(io, pars, " parameters, ", bytes)
end
end
end

_childarray_sum(f, x::AbstractArray) = f(x)
_childarray_sum(f, x) = isleaf(x) ? 0 : sum(y -> _childarray_sum(f, y), Functors.children(x))

# utility functions

underscorise(n::Integer) =
join(reverse(join.(reverse.(Iterators.partition(digits(n), 3)))), '_')
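(Worked example of `underscorise`, not part of the diff; it groups the digits in threes and joins the groups with underscores:)

```julia
julia> underscorise(1234567)
"1_234_567"
```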

function _nan_show(io::IO, x)
[Review comment (Member)] Why is this needed? Layers typically don't show their arrays, and custom layers should define their own, I don't want to take control of how they show their params.

if !isempty(x) && _all(iszero, x)
printstyled(io, " (all zero)", color=:cyan)
elseif _any(isnan, x)
printstyled(io, " (some NaN)", color=:red)
elseif _any(isinf, x)
printstyled(io, " (some Inf)", color=:red)
end
end
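(Quick sketch of `_nan_show`, not part of the diff, assuming the file has been included into Flux so the `_any`/`_all` helpers below are in scope; all-zero and Inf-containing inputs print " (all zero)" and " (some Inf)" analogously:)

```julia
julia> Flux._nan_show(stdout, [1.0, NaN])
 (some NaN)
```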

_any(f, xs::AbstractArray{<:Number}) = any(f, xs)
# _any(f, xs::Union{Tuple,NamedTuple,Zygote.Params}) = any(x -> _any(f, x), xs)
_any(f, xs) = any(x -> _any(f, x), xs)
_any(f, x::Number) = f(x)
# _any(f, x) = false

_all(f, xs) = !_any(!f, xs)
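(Putting it together: a rough sketch of the nested display this file produces, not part of the diff. The alignment of the parameter comment is approximate, and models with more than two parameter arrays also get a `# Total:` summary line from `_big_finale`:)

```julia
julia> Chain(Dense(2, 3, relu))
Chain(
  Dense(2, 3, relu),                    # 9 parameters
)
```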