-
-
Notifications
You must be signed in to change notification settings - Fork 6
Remove @compact(name=...) and replace with NoShow
#19
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 6 commits
8a1832a
6734e89
57b0858
78cec03
75744f9
f9a28cd
38d6b02
5adf64e
345ef7a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,62 @@ | ||
|
|
||
| """ | ||
| NoShow(layer) | ||
| NoShow(string, layer) | ||
|
|
||
| This alters printing (for instance at the REPL prompt) to let you hide the complexity | ||
| of some part of a Flux model. It has no effect on the actual running of the model. | ||
|
|
||
| By default it prints `NoShow(...)` instead of the given layer. | ||
| If you provide a string, it prints that instead -- it can be anything, | ||
| but it may make sense to print the name of a function which will | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. When a callable function to reconstruct the layer is what's desired, I thought a bit about allowing the user to specify the function (and its args) and incorporating something like https://github.com/JuliaLang/julia/blob/197180d8589ad14fc4bc4c23782b76739c4ec5a4/base/show.jl#L522 to make this more robust. I don't think it is worth the implementation complexity, and it could also easily be added later if we really wanted it, so I'm just noting this for posterity. |
||
| re-create the same structure. | ||
|
|
||
| # Examples | ||
|
|
||
| ```jldoctest | ||
| julia> Chain(Dense(2 => 3), NoShow(Parallel(vcat, Dense(3 => 4), Dense(3 => 5))), Dense(9 => 10)) | ||
| Chain( | ||
| Dense(2 => 3), # 9 parameters | ||
| NoShow(...), # 36 parameters | ||
| Dense(9 => 10), # 100 parameters | ||
| ) # Total: 8 arrays, 145 parameters, 1.191 KiB. | ||
|
|
||
| julia> pseudolayer((i,o)::Pair) = NoShow( | ||
| "pseudolayer(\$i => \$o)", | ||
| Parallel(+, Dense(i => o, relu), Dense(i => o, tanh)), | ||
| ) | ||
| pseudolayer (generic function with 1 method) | ||
|
|
||
| julia> Chain(Dense(2 => 3), pseudolayer(3 => 10), Dense(9 => 10)) | ||
| Chain( | ||
| Dense(2 => 3), # 9 parameters | ||
| pseudolayer(3 => 10), # 80 parameters | ||
| Dense(9 => 10), # 100 parameters | ||
| ) # Total: 8 arrays, 189 parameters, 1.379 KiB. | ||
| ``` | ||
| """ | ||
| struct NoShow{T} | ||
| str::String | ||
| layer::T | ||
| end | ||
|
|
||
| NoShow(layer) = NoShow("NoShow(...)", layer) | ||
|
|
||
| Flux.@functor NoShow | ||
|
|
||
| (no::NoShow)(x...) = no.layer(x...) | ||
|
|
||
| Base.show(io::IO, no::NoShow) = print(io, no.str) | ||
|
|
||
| Flux._show_leaflike(::NoShow) = true # I think this is right | ||
| Flux._show_children(::NoShow) = (;) # Seems to be needed? | ||
|
|
||
| function Base.show(io::IO, ::MIME"text/plain", m::NoShow) | ||
| if get(io, :typeinfo, nothing) === nothing # e.g., top level of REPL | ||
| Flux._big_show(io, m) | ||
| elseif !get(io, :compact, false) # e.g., printed inside a Vector, but not a matrix | ||
| Flux._layer_show(io, m) | ||
| else | ||
| show(io, m) | ||
| end | ||
| end | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,28 @@ | ||
|
|
||
| @testset "NoShow" begin | ||
| d23 = Dense(2 => 3) | ||
| d34 = Dense(3 => 4, tanh) | ||
| d35 = Dense(3 => 5, relu) | ||
| d910 = Dense(9 => 10) | ||
|
|
||
| model = Chain(d23, Parallel(vcat, d34, d35), d910) | ||
| m_no = Chain(d23, NoShow(Parallel(vcat, d34, NoShow("zzz", d35))), d910) | ||
|
|
||
| @test sum(length, Flux.params(model)) == sum(length, Flux.params(m_no)) | ||
|
|
||
| xin = randn(Float32, 2, 7) | ||
| @test model(xin) ≈ m_no(xin) | ||
|
|
||
| # gradients | ||
| grad = gradient(m -> m(xin)[1], model)[1] | ||
| g_no = gradient(m -> m(xin)[1], m_no)[1] | ||
|
|
||
| @test grad.layers[2].layers[1].bias ≈ g_no.layers[2].layer.layers[1].bias | ||
| @test grad.layers[2].layers[2].bias ≈ g_no.layers[2].layer.layers[2].layer.bias | ||
|
|
||
| # printing -- see also compact.jl for another test | ||
| @test !contains(string(model), "NoShow(...)") | ||
| @test contains(string(m_no), "NoShow(...)") | ||
| @test !contains(string(m_no), "3 => 4") | ||
| end | ||
|
|
Uh oh!
There was an error while loading. Please reload this page.