"""
    DeepONet(; branch = (64, 32, 32, 16), trunk = (1, 8, 8, 16),
        branch_activation = identity, trunk_activation = identity)

Constructs a DeepONet composed of `Dense` layers. Make sure that the last entries of
`branch` and `trunk` are the same.

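The outputs of the two nets are combined as in [1]: for an input function `u` sampled at
the sensor points and a query location `y`,

    G(u)(y) ≈ Σᵢ bᵢ(u) tᵢ(y),    i = 1, …, p,

where `p` is the number of nodes in the shared last layer.
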
## Keyword Arguments

  - `branch`: Tuple of integers specifying the number of nodes in each layer of the branch net
  - `trunk`: Tuple of integers specifying the number of nodes in each layer of the trunk net
  - `branch_activation`: activation function for the branch net
  - `trunk_activation`: activation function for the trunk net

## References

[1] Lu Lu, Pengzhan Jin, George Em Karniadakis, "DeepONet: Learning nonlinear operators
for identifying differential equations based on the universal approximation theorem of
operators", arXiv: https://arxiv.org/abs/1910.03193

## Example

```jldoctest
deeponet = DeepONet(; branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16))

# output

Branch net :
(
    Chain(
        layer_1 = Dense(64 => 32),  # 2_080 parameters
        layer_2 = Dense(32 => 32),  # 1_056 parameters
        layer_3 = Dense(32 => 16),  # 528 parameters
    ),
)

Trunk net :
(
    Chain(
        layer_1 = Dense(1 => 8),    # 16 parameters
        layer_2 = Dense(8 => 8),    # 72 parameters
        layer_3 = Dense(8 => 16),   # 144 parameters
    ),
)
```
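
A minimal forward-pass sketch (illustrative rather than doctested; it assumes the
standard `Lux` calling convention and random data):

```julia
using Lux, Random

ps, st = Lux.setup(Random.default_rng(), deeponet)
u = rand(Float32, 64, 5)     # sensor values: 64 sensors for each of 5 samples
y = rand(Float32, 1, 10, 5)  # query points: 1 coordinate × 10 points × 5 samples
G, _ = deeponet((u, y), ps, st)
size(G)                      # (10, 5): Gᵤ(y) at each query point, for each sample
```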
| 46 | +""" |
function DeepONet(; branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16),
        branch_activation=identity, trunk_activation=identity)

    # checks for last dimension size
    @argcheck branch[end]==trunk[end] "Branch and Trunk net must share the same number of \
                                       nodes in the last layer. Otherwise Σᵢ bᵢⱼ tᵢₖ won't \
                                       work."

    branch_net = Chain([Dense(branch[i] => branch[i + 1], branch_activation)
                        for i in 1:(length(branch) - 1)]...)

    trunk_net = Chain([Dense(trunk[i] => trunk[i + 1], trunk_activation)
                       for i in 1:(length(trunk) - 1)]...)

    return DeepONet(branch_net, trunk_net)
end

| 64 | +""" |
| 65 | + DeepONet(branch, trunk) |
| 66 | +
|
Constructs a DeepONet from `branch` and `trunk` architectures. Make sure that the outputs
of both nets have the same first dimension.

## Arguments

  - `branch`: `Lux` network to be used as the branch net.
  - `trunk`: `Lux` network to be used as the trunk net.

## References

[1] Lu Lu, Pengzhan Jin, George Em Karniadakis, "DeepONet: Learning nonlinear operators
for identifying differential equations based on the universal approximation theorem of
operators", arXiv: https://arxiv.org/abs/1910.03193

## Example

```jldoctest
branch_net = Chain(Dense(64 => 32), Dense(32 => 32), Dense(32 => 16));
trunk_net = Chain(Dense(1 => 8), Dense(8 => 8), Dense(8 => 16));
don_ = DeepONet(branch_net, trunk_net)

# output

Branch net :
(
    Chain(
        layer_1 = Dense(64 => 32),  # 2_080 parameters
        layer_2 = Dense(32 => 32),  # 1_056 parameters
        layer_3 = Dense(32 => 16),  # 528 parameters
    ),
)

Trunk net :
(
    Chain(
        layer_1 = Dense(1 => 8),    # 16 parameters
        layer_2 = Dense(8 => 8),    # 72 parameters
        layer_3 = Dense(8 => 16),   # 144 parameters
    ),
)
```
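
Continuing the example, a sketch of a forward pass (illustrative rather than doctested;
it assumes the standard `Lux` calling convention):

```julia
using Lux, Random

ps, st = Lux.setup(Random.default_rng(), don_)
u = rand(Float32, 64, 3)    # sensor values: 64 sensors × 3 samples
y = rand(Float32, 1, 7, 3)  # query points: 1 coordinate × 7 points × 3 samples
G, _ = don_((u, y), ps, st)
size(G)                     # (7, 3)
```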
| 108 | +""" |
function DeepONet(branch::L1, trunk::L2) where {L1, L2}
    return @compact(; branch, trunk, dispatch=:DeepONet) do (u, y) # ::AbstractArray{<:Real, M} where {M}
        t = trunk(y)  # p x N x nb
        b = branch(u) # p x nb

        # checks for last dimension size
        @argcheck size(t, 1)==size(b, 1) "Branch and Trunk net must share the same number \
                                          of nodes in the last layer. Otherwise Σᵢ bᵢⱼ tᵢₖ \
                                          won't work."

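        # `b` gets a singleton middle axis so that `batched_mul` computes
        # Σᵢ bᵢ tᵢₖ for every query point k in a single batched matmul.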
        tᵀ = permutedims(t, (2, 1, 3))                         # N x p x nb
        b_ = permutedims(reshape(b, size(b)..., 1), (1, 3, 2)) # p x 1 x nb
        G = batched_mul(tᵀ, b_)                                # N x 1 x nb
        @return dropdims(G; dims=2)
    end
end

function Base.show(io::IO, model::Lux.CompactLuxLayer{:DeepONet})
    Lux._print_wrapper_model(io, "Branch net :\n", model.layers.branch)
    print(io, "\n \n")
    Lux._print_wrapper_model(io, "Trunk net :\n", model.layers.trunk)
end

function Base.show(io::IO, ::MIME"text/plain", x::CompactLuxLayer{:DeepONet})
    show(io, x)
end