diff --git a/src/layers/show.jl b/src/layers/show.jl index 0ae14dd9ee..3fb8fb0d78 100644 --- a/src/layers/show.jl +++ b/src/layers/show.jl @@ -93,7 +93,8 @@ function _big_finale(io::IO, m) if length(ps) > 2 pars = underscorise(sum(length, ps; init=0)) bytes = Base.format_bytes(Base.summarysize(m)) - noncnt = _childarray_sum(_->1, m) - length(ps) + unique_params = IdSet() + noncnt = _childarray_sum(_ -> 1, m, unique_params) - length(ps) if noncnt > 0 nonparam = underscorise(_childarray_sum(length, m) - sum(length, ps; init=0)) printstyled(io, " "^08, "# Total: ", length(ps), " trainable arrays, "; color=:light_black) @@ -110,6 +111,17 @@ end _childarray_sum(f, x::AbstractArray{<:Number}) = f(x) _childarray_sum(f, x) = isleaf(x) ? 0 : sum(y -> _childarray_sum(f, y), Functors.children(x), init=0) +_childarray_sum(f, x::AbstractArray{<:Number}, idset::Base.IdSet) = f(x) +function _childarray_sum(f, x, idset::Base.IdSet) + isleaf(x) && return 0 + + if x in idset + return 0 + else + push!(idset, x) + return sum(y -> _childarray_sum(f, y, idset), Functors.children(x), init = 0) + end +end # utility functions