|
1 | 1 | """ |
2 | | - DeepONet(; branch = (64, 32, 32, 16), trunk = (1, 8, 8, 16), |
3 | | - branch_activation = identity, trunk_activation = identity) |
| 2 | + DeepONet(; branch = (64, 32, 32, 16), trunk = (1, 8, 8, 16), |
| 3 | + branch_activation = identity, trunk_activation = identity) |
4 | 4 |
|
5 | 5 | Constructs a DeepONet composed of Dense layers. Make sure the last node of `branch` and |
6 | 6 | `trunk` are same. |
@@ -44,25 +44,25 @@ Trunk net : |
44 | 44 | ) |
45 | 45 | ``` |
46 | 46 | """ |
47 | | -function DeepONet(; branch = (64, 32, 32, 16), trunk = (1, 8, 8, 16), |
48 | | - branch_activation = identity, trunk_activation = identity) |
| 47 | +function DeepONet(; branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16), |
| 48 | + branch_activation=identity, trunk_activation=identity) |
49 | 49 |
|
50 | | - # checks for last dimension size |
51 | | - @argcheck branch[end] == trunk[end] "Branch and Trunk net must share the same amount of \ |
52 | | - nodes in the last layer. Otherwise Σᵢ bᵢⱼ tᵢₖ won't \ |
53 | | - work." |
| 50 | + # checks for last dimension size |
| 51 | + @argcheck branch[end]==trunk[end] "Branch and Trunk net must share the same amount of \ |
| 52 | + nodes in the last layer. Otherwise Σᵢ bᵢⱼ tᵢₖ won't \ |
| 53 | + work." |
54 | 54 |
|
55 | | - branch_net = Chain([Dense(branch[i] => branch[i+1], branch_activation) |
56 | | - for i in 1:(length(branch)-1)]...) |
| 55 | + branch_net = Chain([Dense(branch[i] => branch[i + 1], branch_activation) |
| 56 | + for i in 1:(length(branch) - 1)]...) |
57 | 57 |
|
58 | | - trunk_net = Chain([Dense(trunk[i] => trunk[i+1], trunk_activation) |
59 | | - for i in 1:(length(trunk)-1)]...) |
| 58 | + trunk_net = Chain([Dense(trunk[i] => trunk[i + 1], trunk_activation) |
| 59 | + for i in 1:(length(trunk) - 1)]...) |
60 | 60 |
|
61 | | - return DeepONet(branch_net, trunk_net) |
| 61 | + return DeepONet(branch_net, trunk_net) |
62 | 62 | end |
63 | 63 |
|
64 | 64 | """ |
65 | | - DeepONet(branch, trunk) |
| 65 | + DeepONet(branch, trunk) |
66 | 66 |
|
67 | 67 | Constructs a DeepONet from a `branch` and `trunk` architectures. Make sure that both the |
68 | 68 | nets output should have the same first dimension. |
@@ -107,28 +107,28 @@ Trunk net : |
107 | 107 | ``` |
108 | 108 | """ |
109 | 109 | function DeepONet(branch::L1, trunk::L2) where {L1, L2} |
110 | | - return @compact(; branch, trunk, dispatch = :DeepONet) do (u, y) # ::AbstractArray{<:Real, M} where {M} |
111 | | - t = trunk(y) # p x N x nb |
112 | | - b = branch(u) # p x nb |
113 | | - |
114 | | - # checks for last dimension size |
115 | | - @argcheck size(t, 1) == size(b, 1) "Branch and Trunk net must share the same amount \ |
116 | | - of nodes in the last layer. Otherwise Σᵢ bᵢⱼ tᵢₖ \ |
117 | | - won't work." |
118 | | - |
119 | | - tᵀ = permutedims(t, (2, 1, 3)) # N x p x nb |
120 | | - b_ = permutedims(reshape(b, size(b)..., 1), (1, 3, 2)) # p x 1 x nb |
121 | | - G = batched_mul(tᵀ, b_) # N x 1 X nb |
122 | | - @return dropdims(G; dims = 2) |
123 | | - end |
| 110 | + return @compact(; branch, trunk, dispatch=:DeepONet) do (u, y) # ::AbstractArray{<:Real, M} where {M} |
| 111 | + t = trunk(y) # p x N x nb |
| 112 | + b = branch(u) # p x nb |
| 113 | + |
| 114 | + # checks for last dimension size |
| 115 | + @argcheck size(t, 1)==size(b, 1) "Branch and Trunk net must share the same amount \ |
| 116 | + of nodes in the last layer. Otherwise Σᵢ bᵢⱼ tᵢₖ \ |
| 117 | + won't work." |
| 118 | + |
| 119 | + tᵀ = permutedims(t, (2, 1, 3)) # N x p x nb |
| 120 | + b_ = permutedims(reshape(b, size(b)..., 1), (1, 3, 2)) # p x 1 x nb |
| 121 | + G = batched_mul(tᵀ, b_) # N x 1 X nb |
| 122 | + @return dropdims(G; dims=2) |
| 123 | + end |
124 | 124 | end |
125 | 125 |
|
126 | 126 | function Base.show(io::IO, model::Lux.CompactLuxLayer{:DeepONet}) |
127 | | - Lux._print_wrapper_model(io, "Branch net :\n", model.layers.branch) |
128 | | - print(io, "\n \n") |
129 | | - Lux._print_wrapper_model(io, "Trunk net :\n", model.layers.trunk) |
| 127 | + Lux._print_wrapper_model(io, "Branch net :\n", model.layers.branch) |
| 128 | + print(io, "\n \n") |
| 129 | + Lux._print_wrapper_model(io, "Trunk net :\n", model.layers.trunk) |
130 | 130 | end |
131 | 131 |
|
132 | 132 | function Base.show(io::IO, ::MIME"text/plain", x::CompactLuxLayer{:DeepONet}) |
133 | | - show(io, x) |
| 133 | + show(io, x) |
134 | 134 | end |
0 commit comments