Skip to content

Commit bf6d97b

Browse files
committed
refactor multi_layer_feed_forward to be more useful
1 parent 22be907 commit bf6d97b

File tree

1 file changed

+21
-5
lines changed

1 file changed

+21
-5
lines changed

src/utils.jl

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,22 @@
1-
function multi_layer_feed_forward(input_length, output_length; width::Int = 5,
2-
depth::Int = 1, activation = tanh)
3-
Lux.Chain(Lux.Dense(input_length, width, activation),
4-
[Lux.Dense(width, width, activation) for _ in 1:(depth)]...,
5-
Lux.Dense(width, output_length))
1+
"""
2+
multi_layer_feed_forward(; n_input, n_output, width::Int = 4,
3+
depth::Int = 1, activation = tanh, use_bias = true, initial_scaling_factor = 1e-8)
4+
5+
Create a Lux.jl `Chain` for use in [`NeuralNetworkBlock`](@ref)s. The weights of the last layer
6+
are multiplied by the `initial_scaling_factor` in order to make the initial contribution
7+
of the network small and thus help with achieving a stable starting position for the training.
8+
"""
9+
function multi_layer_feed_forward(; n_input, n_output, width::Int = 4,
10+
depth::Int = 1, activation = tanh, use_bias = true, initial_scaling_factor = 1e-8)
11+
Lux.Chain(
12+
Lux.Dense(n_input, width, activation; use_bias),
13+
[Lux.Dense(width, width, activation; use_bias) for _ in 1:(depth)]...,
14+
Lux.Dense(width, n_output;
15+
init_weight = (rng, a...) -> initial_scaling_factor *
16+
Lux.kaiming_uniform(rng, a...), use_bias)
17+
)
18+
end
19+
20+
function multi_layer_feed_forward(n_input, n_output; kwargs...)
21+
multi_layer_feed_forward(; n_input, n_output, kwargs...)
622
end

0 commit comments

Comments
 (0)