
Commit cd63565

add fully connected nn constructor
1 parent 574e257 commit cd63565

File tree

1 file changed: +48 −1 lines changed


src/flows/utils.jl

Lines changed: 48 additions & 1 deletion
@@ -2,6 +2,8 @@ using Bijectors: transformed
 using Flux
 
 """
+    mlp3(input_dim::Int, hidden_dims::Int, output_dim::Int; activation=Flux.leakyrelu)
+
 A simple wrapper for a 3 layer dense MLP
 """
 function mlp3(input_dim::Int, hidden_dims::Int, output_dim::Int; activation=Flux.leakyrelu)
@@ -12,7 +14,52 @@ function mlp3(input_dim::Int, hidden_dims::Int, output_dim::Int; activation=Flux.leakyrelu)
     )
 end
 
+"""
+    fnn(
+        input_dim::Int,
+        hidden_dims::AbstractVector{<:Int},
+        output_dim::Int;
+        inlayer_activation=Flux.leakyrelu,
+        output_activation=Flux.tanh,
+    )
+
+Create a fully connected neural network (FNN).
+
+# Arguments
+- `input_dim::Int`: The dimension of the input layer.
+- `hidden_dims::AbstractVector{<:Int}`: A vector of integers specifying the dimensions of the hidden layers.
+- `output_dim::Int`: The dimension of the output layer.
+- `inlayer_activation`: The activation function for the hidden layers. Defaults to `Flux.leakyrelu`.
+- `output_activation`: The activation function for the output layer. Defaults to `Flux.tanh`.
+
+# Returns
+- A `Flux.Chain` representing the FNN.
+"""
+function fnn(
+    input_dim::Int,
+    hidden_dims::AbstractVector{<:Int},
+    output_dim::Int;
+    inlayer_activation=Flux.leakyrelu,
+    output_activation=Flux.tanh,
+)
+    # Create a chain of dense layers
+    # First layer
+    layers = Any[Flux.Dense(input_dim, hidden_dims[1], inlayer_activation)]
+
+    # Hidden layers
+    for i in 1:(length(hidden_dims) - 1)
+        push!(
+            layers,
+            Flux.Dense(hidden_dims[i], hidden_dims[i + 1], inlayer_activation),
+        )
+    end
+
+    # Output layer
+    push!(layers, Flux.Dense(hidden_dims[end], output_dim, output_activation))
+    return Chain(layers...)
+end
+
 function create_flow(Ls, q₀)
     ts = reduce(∘, Ls)
     return transformed(q₀, ts)
-end
+end
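
A minimal usage sketch (not part of the commit) of how the new constructor could be called; the dimensions and batch size below are arbitrary choices for illustration:

    using Flux

    # 4-dimensional input, two hidden layers of width 32, 2-dimensional output
    net = fnn(4, [32, 32], 2)

    # forward pass on a random batch of 16 samples; the default tanh output keeps values in (-1, 1)
    x = randn(Float32, 4, 16)
    y = net(x)    # size(y) == (2, 16)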
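
For the surrounding create_flow helper, a similarly hedged sketch, assuming Bijectors' PlanarLayer and Distributions' MvNormal as the layers and base distribution (illustrative choices, not dictated by this file):

    using Bijectors, Distributions, LinearAlgebra

    q₀ = MvNormal(zeros(2), I)            # 2-d standard normal base distribution
    Ls = [PlanarLayer(2) for _ in 1:4]    # four planar layers to compose
    flow = create_flow(Ls, q₀)            # transformed(q₀, reduce(∘, Ls))
    x = rand(flow)                        # sample by pushing a base draw through the layers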

0 commit comments
