diff --git a/src/compact.jl b/src/compact.jl
index 3b8e17f..a258ddf 100644
--- a/src/compact.jl
+++ b/src/compact.jl
@@ -1,13 +1,13 @@
 import Flux: _big_show
 
 """
-    @compact(forward::Function; name=nothing, parameters...)
+    @compact(forward::Function[, layer_type]; name=nothing, parameters...)
 
 Creates a layer by specifying some `parameters`, in the form of keywords,
 and (usually as a `do` block) a function for the forward pass.
 You may think of `@compact` as a specialized `let` block creating local variables
-that are trainable in Flux.
-Declared variable names may be used within the body of the `forward` function.
+that are trainable in Flux. Declared variable names may be used within the
+body of the `forward` function.
 
 Here is a linear model:
 
@@ -28,7 +28,7 @@ end
 d(ones(5, 10)) # 7×10 Matrix as output.
 ```
 
-Finally, here is a simple MLP:
+Here is a simple MLP:
 
 ```
 using Flux
@@ -78,11 +78,33 @@ println(model) # "Linear(3 => 1)"
 
 This can be useful when using `@compact` to hierarchically construct
 complex models to be used inside a `Chain`.
+
+You can also specify a symbol to identify the type of layer, which is
+useful for dispatching on types of layers (since the function block
+will generate a new type each time it is evaluated):
+
+```
+model = @compact(MyLayer, w=rand(3)) do x
+  sum(w .* x)
+end
+
+f(::CompactLayer{:MyLayer}) = 1
+f(::CompactLayer{:Default}) = 0
+
+println(f(model)) # 1
+```
 """
 macro compact(fex, kwexs...)
   # check input
   Meta.isexpr(fex, :(->)) || error("expects a do block")
+  # Strip an optional leading Symbol (the layer tag) before requiring keywords:
+  (layer_symbol, kwexs) = if !isempty(kwexs) && first(kwexs) isa Symbol
+    (first(kwexs), Base.tail(kwexs))
+  else
+    (:Default, kwexs)
+  end
+  layer_symbol = QuoteNode(layer_symbol)
   isempty(kwexs) && error("expects keyword arguments")
   all(ex -> Meta.isexpr(ex, (:kw,:(=))), kwexs) || error("expects only keyword argumens")
 
   # check if user has named layer:
@@ -112,7 +134,7 @@ macro compact(fex, kwexs...)
   return esc(quote
     let
       $(assigns...)
-      $CompactLayer($fex, $name, ($layer, $input, $block), $setup; $(vars...))
+      $CompactLayer(Val($layer_symbol), $fex, $name, ($layer, $input, $block), $setup; $(vars...))
     end
   end)
 end
@@ -128,17 +150,21 @@ function addprefix!(ex::Expr, self, vars)
 end
 addprefix!(not_ex, self, vars) = nothing
 
-struct CompactLayer{F,NT1<:NamedTuple,NT2<:NamedTuple}
+struct CompactLayer{S,F,NT1<:NamedTuple,NT2<:NamedTuple}
+  symbol::Val{S}  # type-level tag, so layers can be dispatched on as CompactLayer{:Tag}
   fun::F
   name::Union{String,Nothing}
   strings::NTuple{3,String}
   setup_strings::NT1
   variables::NT2
 end
-CompactLayer(f::Function, name::Union{String,Nothing}, str::Tuple, setup_str::NamedTuple; kw...) = CompactLayer(f, name, str, setup_str, NamedTuple(kw))
+function CompactLayer(symb::Val, f::Function, name::Union{String,Nothing}, str::Tuple, setup_str::NamedTuple; kw...)
+  return CompactLayer(symb, f, name, str, setup_str, NamedTuple(kw))
+end
 (m::CompactLayer)(x...) = m.fun(m.variables, x...)
 CompactLayer(args...) = error("CompactLayer is meant to be constructed by the macro")
 Flux.@functor CompactLayer
+layer_symbol(::CompactLayer{S}) where {S} = S  # the layer's tag; `:Default` when @compact got none
 
 Flux._show_children(m::CompactLayer) = m.variables
 
@@ -167,6 +193,9 @@ function Flux._big_show(io::IO, obj::CompactLayer, indent::Int=0, name=nothing)
     layer, input, block = obj.strings
     pre, post = ("(", ")")
     println(io, " "^indent, isnothing(name) ? "" : "$name = ", layer, pre)
+    if layer_symbol(obj) != :Default
+      println(io, " "^(indent+2), string(layer_symbol(obj)), ",")
+    end
     for k in keys(obj.variables)
       v = obj.variables[k]
       if Flux._show_leaflike(v)
diff --git a/test/compact.jl b/test/compact.jl
index 06ebe8e..38b1033 100644
--- a/test/compact.jl
+++ b/test/compact.jl
@@ -1,4 +1,4 @@
-import Fluxperimental: @compact
+import Fluxperimental: @compact, CompactLayer
 
 # Strip both strings of spaces, and then test:
 function similar_strings(s1, s2)
@@ -182,3 +182,96 @@ end
 
   @test similar_strings(get_model_string(model), expected_string)
 end
+
+@testset "Dispatch using symbols" begin
+  model1 = @compact(W=randn(32)) do x
+    W .* x
+  end
+  model2 = @compact(MyCustomLayer, W=randn(32)) do x
+    W .* x
+  end
+  @eval my_custom_function(::CompactLayer{:Default}) = :Default
+  @eval my_custom_function(::CompactLayer{:MyCustomLayer}) = :MyCustomLayer
+
+  @test model1 isa CompactLayer{:Default}
+  @test model2 isa CompactLayer{:MyCustomLayer}
+  @test my_custom_function(model1) == :Default
+  @test my_custom_function(model2) == :MyCustomLayer
+
+  expected_string2 = """@compact(
+    MyCustomLayer,
+    W = randn(32), # 32 parameters
+  ) do x
+    W .* x
+  end"""
+  @test similar_strings(get_model_string(model2), expected_string2)
+
+  @testset "Nested symbolic layers" begin
+    num_features = 1
+    num_out = 2
+    d_attn = 4
+    d_value = 12
+    num_heads = 3
+
+    model3 = @compact(
+      SelfAttention,
+      out = Dense(num_heads * d_value => num_out),
+      heads = [
+        @compact(
+          Head,
+          K = Dense(num_features => d_attn),
+          V = Dense(num_features => d_value),
+          Q = Dense(num_features => d_attn)
+        ) do x
+          k, v, q = K(x), V(x), Q(x)
+          x = sum(k .* q; dims=1) ./ sqrt(d_attn)
+          softmax(x; dims=2) .* v
+        end for _ in 1:num_heads
+      ]
+    ) do x
+      out(vcat([h(x) for h in heads]...))
+    end
+    @test model3 isa CompactLayer{:SelfAttention}
+    @test all(t -> isa(t, CompactLayer{:Head}), model3.variables.heads)
+
+    expected_string3 = """@compact(
+      SelfAttention,
+      out = Dense(36 => 2), # 74 parameters
+      heads = Array(
+        @compact(
+          Head,
+          K = Dense(1 => 4), # 8 parameters
+          V = Dense(1 => 12), # 24 parameters
+          Q = Dense(1 => 4), # 8 parameters
+        ) do x
+          (k, v, q) = (K(x), V(x), Q(x))
+          x = sum(k .* q; dims = 1) ./ sqrt(d_attn)
+          softmax(x; dims = 2) .* v
+        end,
+        @compact(
+          Head,
+          K = Dense(1 => 4), # 8 parameters
+          V = Dense(1 => 12), # 24 parameters
+          Q = Dense(1 => 4), # 8 parameters
+        ) do x
+          (k, v, q) = (K(x), V(x), Q(x))
+          x = sum(k .* q; dims = 1) ./ sqrt(d_attn)
+          softmax(x; dims = 2) .* v
+        end,
+        @compact(
+          Head,
+          K = Dense(1 => 4), # 8 parameters
+          V = Dense(1 => 12), # 24 parameters
+          Q = Dense(1 => 4), # 8 parameters
+        ) do x
+          (k, v, q) = (K(x), V(x), Q(x))
+          x = sum(k .* q; dims = 1) ./ sqrt(d_attn)
+          softmax(x; dims = 2) .* v
+        end,
+      ),
+    ) do x
+      out(vcat([h(x) for h = heads]...))
+    end # Total: 20 arrays, 194 parameters, 3.515 KiB."""
+    @test similar_strings(get_model_string(model3), expected_string3)
+  end
+end