-
Notifications
You must be signed in to change notification settings - Fork 83
Closed
Labels
enhancement — New feature or request
Description
A simple implementation:
"""
    BroadcastLayer{T<:NamedTuple}

Container layer whose sub-layers are stored as a `NamedTuple` in the `layers`
field. The field is registered with `AbstractExplicitContainerLayer` through the
`(:layers,)` type parameter.
"""
struct BroadcastLayer{T<:NamedTuple} <: AbstractExplicitContainerLayer{(:layers,)}
    layers::T
end
"""
    BroadcastLayer(layers...)

Build a `BroadcastLayer` from positional `layers`, naming them `layer_1`,
`layer_2`, … in the order given.

# Throws
- `ArgumentError` if any layer reports a non-zero `statelength`.
"""
function BroadcastLayer(layers...)
    # Reject the first layer that carries state.
    for l in layers
        iszero(statelength(l)) ||
            throw(ArgumentError("Stateful layer `$l` are not supported for `BroadcastLayer`."))
    end
    layer_names = ntuple(length(layers)) do i
        return Symbol("layer_$i")
    end
    named_layers = NamedTuple{layer_names}(layers)
    return BroadcastLayer(named_layers)
end
"""
    BroadcastLayer(; kwargs...)

Build a `BroadcastLayer` whose sub-layers are named by the keyword names, e.g.
`BroadcastLayer(; a=layer1, b=layer2)` stores the layers under `:a` and `:b`.

# Throws
- `ArgumentError` if any layer reports a non-zero `statelength`.
"""
function BroadcastLayer(; kwargs...)
    # Collect the keywords into a NamedTuple. The original forwarded an undefined
    # `connection` variable here, so every call failed with `UndefVarError`.
    layers = (; kwargs...)
    # Same stateless requirement as the positional constructor.
    for l in layers
        if !iszero(statelength(l))
            throw(ArgumentError("Stateful layer `$l` are not supported for `BroadcastLayer`."))
        end
    end
    # A single NamedTuple argument selects the struct's default constructor.
    return BroadcastLayer(layers)
end
# Forward pass: broadcast `Lux.apply` across the stored layers, `x`, and the
# values of `ps` and `st`, keeping only the first element of each result.
# Returns the collected outputs together with `st`, which is passed through
# unchanged.
function (m::BroadcastLayer)(x, ps, st::NamedTuple{names}) where {names}
    # Drop whatever `Lux.apply` returns after its first element.
    apply_first(layer, input, p, s) = first(Lux.apply(layer, input, p, s))
    outputs = broadcast(apply_first, values(m.layers), x, values(ps), values(st))
    return outputs, st
end
"""
    Base.keys(m::BroadcastLayer)

Return the keys of the `layers` NamedTuple stored in `m`.
"""
Base.keys(m::BroadcastLayer) = Base.keys(getfield(m, :layers))
# Originally posted by @avik-pal in #282 (comment)
Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
enhancement — New feature or request