Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/Fluxperimental.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ include("chain.jl")

include("compact.jl")

include("noshow.jl")
export NoShow

include("new_recur.jl")

end # module Fluxperimental
40 changes: 2 additions & 38 deletions src/compact.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,20 +68,6 @@ for epoch in 1:1000
Flux.train!((m,x,y) -> (m(x) - y)^2, model, data, optim)
end
```

You may also specify a `name` for the model, which will
be used instead of the default printout, which gives a verbatim
representation of the code used to construct the model:

```
model = @compact(w=rand(3), name="Linear(3 => 1)") do x
sum(w .* x)
end
println(model) # "Linear(3 => 1)"
```

This can be useful when using `@compact` to hierarchically construct
complex models to be used inside a `Chain`.
"""
macro compact(_exs...)
# check inputs, extracting function expression fex and unprocessed keyword arguments _kwexs
Expand All @@ -108,16 +94,6 @@ macro compact(_exs...)
kwexs2 = map(ex -> Expr(:kw, ex.args...), _kwexs) # handle keyword arguments provided before semicolon
kwexs = (kwexs1..., kwexs2...)

# check if user has named layer:
name = findfirst(ex -> ex.args[1] == :name, kwexs)
if name !== nothing && kwexs[name].args[2] !== nothing
length(kwexs) == 1 && error("expects keyword arguments")
name_str = kwexs[name].args[2]
# remove name from kwexs (a tuple)
kwexs = (kwexs[1:name-1]..., kwexs[name+1:end]...)
name = name_str
end

# make strings
layer = "@compact"
setup = NamedTuple(map(ex -> Symbol(string(ex.args[1])) => string(ex.args[2]), kwexs))
Expand All @@ -136,7 +112,7 @@ macro compact(_exs...)
fex = supportself(fex, vars)

# assemble
return esc(:($CompactLayer($fex, $name, ($layer, $input, $block), $setup; $(kwexs...))))
return esc(:($CompactLayer($fex, ($layer, $input, $block), $setup; $(kwexs...))))
end

function supportself(fex::Expr, vars)
Expand All @@ -155,12 +131,11 @@ end

struct CompactLayer{F,NT1<:NamedTuple,NT2<:NamedTuple}
fun::F
name::Union{String,Nothing}
strings::NTuple{3,String}
setup_strings::NT1
variables::NT2
end
CompactLayer(f::Function, name::Union{String,Nothing}, str::Tuple, setup_str::NamedTuple; kw...) = CompactLayer(f, name, str, setup_str, NamedTuple(kw))
CompactLayer(f::Function, str::Tuple, setup_str::NamedTuple; kw...) = CompactLayer(f, str, setup_str, NamedTuple(kw))
(m::CompactLayer)(x...) = m.fun(m.variables, x...)
CompactLayer(args...) = error("CompactLayer is meant to be constructed by the macro")
Flux.@functor CompactLayer
Expand All @@ -179,16 +154,6 @@ end

function Flux._big_show(io::IO, obj::CompactLayer, indent::Int=0, name=nothing)
setup_strings = obj.setup_strings
local_name = obj.name
has_explicit_name = local_name !== nothing
if has_explicit_name
if indent != 0 || length(Flux.params(obj)) <= 2
_just_show_params(io, local_name, obj, indent)
else # indent == 0
print(io, local_name)
Flux._big_finale(io, obj)
end
else # no name, so print normally
layer, input, block = obj.strings
pre, post = ("(", ")")
println(io, " "^indent, isnothing(name) ? "" : "$name = ", layer, pre)
Expand Down Expand Up @@ -220,7 +185,6 @@ function Flux._big_show(io::IO, obj::CompactLayer, indent::Int=0, name=nothing)
else
println(io, ",")
end
end
end

# Modified from src/layers/show.jl
Expand Down
62 changes: 62 additions & 0 deletions src/noshow.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@

"""
NoShow(layer)
NoShow(string, layer)

This alters printing (for instance at the REPL prompt) to let you hide the complexity
of some part of a Flux model. It has no effect on the actual running of the model.

By default it prints `NoShow(...)` instead of the given layer.
If you provide a string, it prints that instead -- it can be anything,
but it may make sense to print the name of a function which will
Copy link
Contributor

@gaurav-arya gaurav-arya Aug 30, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When a callable function to reconstruct the layer is what's desired, I thought a bit about allowing the user to specify the function (and its args) and incorporating something like https://github.com/JuliaLang/julia/blob/197180d8589ad14fc4bc4c23782b76739c4ec5a4/base/show.jl#L522 to make this more robust. I don't think it is worth the implementation complexity, and could also easily be added later if we really wanted it, so just noting for posterity.

re-create the same structure.

# Examples

```jldoctest
julia> Chain(Dense(2 => 3), NoShow(Parallel(vcat, Dense(3 => 4), Dense(3 => 5))), Dense(9 => 10))
Chain(
Dense(2 => 3), # 9 parameters
NoShow(...), # 36 parameters
Dense(9 => 10), # 100 parameters
) # Total: 8 arrays, 145 parameters, 1.191 KiB.

julia> pseudolayer((i,o)::Pair) = NoShow(
"pseudolayer(\$i => \$o)",
Parallel(+, Dense(i => o, relu), Dense(i => o, tanh)),
)
pseudolayer (generic function with 1 method)

julia> Chain(Dense(2 => 3), pseudolayer(3 => 10), Dense(9 => 10))
Chain(
Dense(2 => 3), # 9 parameters
pseudolayer(3 => 10), # 80 parameters
Dense(9 => 10), # 100 parameters
) # Total: 8 arrays, 189 parameters, 1.379 KiB.
```
"""
struct NoShow{T}
str::String
layer::T
end

NoShow(layer) = NoShow("NoShow(...)", layer)

Flux.@functor NoShow

(no::NoShow)(x...) = no.layer(x...)

Base.show(io::IO, no::NoShow) = print(io, no.str)

Flux._show_leaflike(::NoShow) = true # I think this is right
Flux._show_children(::NoShow) = (;) # Seems to be needed?

function Base.show(io::IO, ::MIME"text/plain", m::NoShow)
if get(io, :typeinfo, nothing) === nothing # e.g., top level of REPL
Flux._big_show(io, m)
elseif !get(io, :compact, false) # e.g., printed inside a Vector, but not a matrix
Flux._layer_show(io, m)
else
show(io, m)
end
end
56 changes: 11 additions & 45 deletions test/compact.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ end
(1, 128),
(1,),
]
@test size(model(randn(n_in, 32))) == (1, 32)
@test size(model(randn(Float32, n_in, 32))) == (1, 32)
end

@testset "String representations" begin
Expand All @@ -118,15 +118,6 @@ end
@test similar_strings(get_model_string(model), expected_string)
end

@testset "Custom naming" begin
model = @compact(w=Dense(32, 32), name="Linear(...)") do x, y
tmp = sum(w(x))
return tmp + y
end
expected_string = "Linear(...) # 1_056 parameters"
@test similar_strings(get_model_string(model), expected_string)
end

@testset "Hierarchical models" begin
model1 = @compact(w1=Dense(32=>32, relu), w2=Dense(32=>32, relu)) do x
w2(w1(x))
Expand Down Expand Up @@ -161,41 +152,6 @@ end
@test similar_strings(get_model_string(model), expected_string)
end

@testset "Hierarchy with inner model named" begin
model = @compact(
w1=@compact(w1=randn(32, 32), name="Model(32)") do x
w1 * x
end,
w2=randn(32, 32),
w3=randn(32),
) do x
w2 * w1(x)
end
expected_string = """@compact(
Model(32), # 1_024 parameters
w2 = randn(32, 32), # 1_024 parameters
w3 = randn(32), # 32 parameters
) do x
w2 * w1(x)
end # Total: 3 arrays, 2_080 parameters, 17.089 KiB."""
@test similar_strings(get_model_string(model), expected_string)
end

@testset "Hierarchy with outer model named" begin
model = @compact(
w1=@compact(w1=randn(32, 32)) do x
w1 * x
end,
w2=randn(32, 32),
w3=randn(32),
name="Model(32)"
) do x
w2 * w1(x)
end
expected_string = """Model(32) # Total: 3 arrays, 2_080 parameters, 17.057KiB."""
@test similar_strings(get_model_string(model), expected_string)
end

@testset "Dependent initializations" begin
# Test that initialization lines cannot depend on each other
@test_throws UndefVarError @compact(y = 3, z = y^2) do x
Expand Down Expand Up @@ -234,3 +190,13 @@ end
end
end


@testset "Custom naming of @compact with NoShow" begin
_model = @compact(w=Dense(32, 32)) do x, y
tmp = sum(w(x))
return tmp + y
end
model = NoShow(_model)
expected_string = "NoShow(...) # 1_056 parameters"
@test similar_strings(get_model_string(model), expected_string)
end
28 changes: 28 additions & 0 deletions test/noshow.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@

@testset "NoShow" begin
d23 = Dense(2 => 3)
d34 = Dense(3 => 4, tanh)
d35 = Dense(3 => 5, relu)
d910 = Dense(9 => 10)

model = Chain(d23, Parallel(vcat, d34, d35), d910)
m_no = Chain(d23, NoShow(Parallel(vcat, d34, NoShow("zzz", d35))), d910)

@test sum(length, Flux.params(model)) == sum(length, Flux.params(m_no))

xin = randn(Float32, 2, 7)
@test model(xin) ≈ m_no(xin)

# gradients
grad = gradient(m -> m(xin)[1], model)[1]
g_no = gradient(m -> m(xin)[1], m_no)[1]

@test grad.layers[2].layers[1].bias ≈ g_no.layers[2].layer.layers[1].bias
@test grad.layers[2].layers[2].bias ≈ g_no.layers[2].layer.layers[2].layer.bias

# printing -- see also compact.jl for another test
@test !contains(string(model), "NoShow(...)")
@test contains(string(m_no), "NoShow(...)")
@test !contains(string(m_no), "3 => 4")
end

1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ using Flux, Fluxperimental
include("chain.jl")

include("compact.jl")
include("noshow.jl")

include("new_recur.jl")

Expand Down