-function multi_layer_feed_forward(input_length, output_length; width::Int = 5,
-        depth::Int = 1, activation = tanh)
-    Lux.Chain(Lux.Dense(input_length, width, activation),
-        [Lux.Dense(width, width, activation) for _ in 1:(depth)]...,
-        Lux.Dense(width, output_length))
| 1 | +""" |
| 2 | + multi_layer_feed_forward(; n_input, n_output, width::Int = 4, |
| 3 | + depth::Int = 1, activation = tanh, use_bias = true, initial_scaling_factor = 1e-8) |
| 4 | +
|
+Create a Lux.jl `Chain` for use in [`NeuralNetworkBlock`](@ref)s. The weights of the last
+layer are multiplied by `initial_scaling_factor` in order to make the initial contribution
+of the network small, which helps achieve a stable starting point for training.
+"""
+function multi_layer_feed_forward(; n_input, n_output, width::Int = 4,
+        depth::Int = 1, activation = tanh, use_bias = true, initial_scaling_factor = 1e-8)
+    Lux.Chain(
+        Lux.Dense(n_input, width, activation; use_bias),
+        [Lux.Dense(width, width, activation; use_bias) for _ in 1:(depth)]...,
+        Lux.Dense(width, n_output;
+            init_weight = (rng, a...) -> initial_scaling_factor *
+                                         Lux.kaiming_uniform(rng, a...), use_bias)
+    )
+end
+
+function multi_layer_feed_forward(n_input, n_output; kwargs...)
+    multi_layer_feed_forward(; n_input, n_output, kwargs...)
 end
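
A minimal usage sketch of the new keyword API (assuming `Lux` and `Random` are loaded and `multi_layer_feed_forward` is in scope; the dimensions and input values below are illustrative):

using Lux, Random

# Two inputs, one output, default width and depth; the positional
# method simply forwards to the keyword method.
chain = multi_layer_feed_forward(2, 1)

# Initialize parameters and states for the chain.
rng = Random.default_rng()
ps, st = Lux.setup(rng, chain)

# Because the last layer's weights are scaled by `initial_scaling_factor`,
# the initial output is approximately zero (up to the last layer's bias).
y, _ = chain([1.0, 2.0], ps, st)

Scaling only the final layer's `init_weight` leaves the hidden layers at their standard Kaiming initialization, so the hidden representation is well-conditioned from the start while the network's initial contribution stays negligible.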