|
"""
`DeepONet(architecture_branch::Tuple, architecture_trunk::Tuple,
            act_branch = identity, act_trunk = identity;
            init_branch = Flux.glorot_uniform,
            init_trunk = Flux.glorot_uniform,
            bias_branch=true, bias_trunk=true)`
`DeepONet(branch_net::Flux.Chain, trunk_net::Flux.Chain)`

Create an (unstacked) DeepONet architecture as proposed by Lu et al.
arXiv:1910.03193

The model works as follows:

x --- branch --
               |
               -⊠--u-
               |
y --- trunk ---

Where `x` represents the input function, discretely evaluated at its respective sensors.
So the input is of shape [m] for one instance or [m x b] for a training set.
`y` are the probing locations for the operator to be trained. It has shape [N x n] for
N different variables in the PDE (i.e. spatial and temporal coordinates) with each n distinct evaluation points.
`u` is the solution of the queried instance of the PDE, given by the specific choice of parameters.

Both inputs `x` and `y` are multiplied together via dot product Σᵢ bᵢⱼ tᵢₖ.

You can set up this architecture in two ways:

1. By specifying the architecture and all its parameters as given above. This always creates
   `Dense` layers for the branch and trunk net and corresponds to the DeepONet proposed by Lu et al.

2. By passing two architectures in the form of two Chain structs directly. Do this if you want more
   flexibility and e.g. use an RNN or CNN instead of simple `Dense` layers.

Strictly speaking, DeepONet does not imply either of the branch or trunk net to be a simple
DNN. Usually though, this is the case which is why it's treated as the default case here.

# Example

Consider a transient 1D advection problem ∂ₜu + u ⋅ ∇u = 0, with an IC u(x,0) = g(x).
We are given several (b = 200) instances of the IC, discretized at 50 points each and want
to query the solution for 100 different locations and times [0;1].

That makes the branch input of shape [50 x 200] and the trunk input of shape [2 x 100]. So the
input for the branch net is 50 and 100 for the trunk net.

# Usage

```julia
julia> model = DeepONet((32,64,72), (24,64,72))
DeepONet with
branch net: (Chain(Dense(32, 64), Dense(64, 72)))
Trunk net: (Chain(Dense(24, 64), Dense(64, 72)))

julia> model = DeepONet((32,64,72), (24,64,72), σ, tanh; init_branch=Flux.glorot_normal, bias_trunk=false)
DeepONet with
branch net: (Chain(Dense(32, 64, σ), Dense(64, 72, σ)))
Trunk net: (Chain(Dense(24, 64, tanh; bias=false), Dense(64, 72, tanh; bias=false)))

julia> branch = Chain(Dense(2,128),Dense(128,64),Dense(64,72))
Chain(
  Dense(2, 128),                        # 384 parameters
  Dense(128, 64),                       # 8_256 parameters
  Dense(64, 72),                        # 4_680 parameters
)                   # Total: 6 arrays, 13_320 parameters, 52.406 KiB.

julia> trunk = Chain(Dense(1,24),Dense(24,72))
Chain(
  Dense(1, 24),                         # 48 parameters
  Dense(24, 72),                        # 1_800 parameters
)                   # Total: 4 arrays, 1_848 parameters, 7.469 KiB.

julia> model = DeepONet(branch,trunk)
DeepONet with
branch net: (Chain(Dense(2, 128), Dense(128, 64), Dense(64, 72)))
Trunk net: (Chain(Dense(1, 24), Dense(24, 72)))
```
"""
struct DeepONet
    # Network applied to the sampled input function `x` (sensor values).
    branch_net::Flux.Chain
    # Network applied to the query locations `y`.
    trunk_net::Flux.Chain
    # NOTE(review): `Flux.Chain` is an abstract-ish parametric type, so these
    # fields are not concretely typed; consider parametrizing the struct
    # (e.g. `struct DeepONet{B,T}`) for better specialization — TODO confirm
    # against the rest of the package before changing.
end
| 84 | + |
# Convenience constructor: build branch and trunk as stacks of `Dense` layers
# from the given layer-width tuples, as in the original DeepONet of Lu et al.
#
# Arguments:
#   architecture_branch / architecture_trunk — tuples of layer widths; the
#       last entries must match so the latent contraction Σᵢ bᵢⱼ tᵢₖ is defined.
#   act_branch / act_trunk — activation used in every layer of each subnet.
# Keywords:
#   init_branch / init_trunk — weight initializers passed to the layers.
#   bias_branch / bias_trunk — whether each subnet's layers carry a bias.
# Throws:
#   DimensionMismatch — if the final widths of branch and trunk differ.
#   (Replaces the former `@assert`, which is input validation and may be
#   compiled out at higher optimization levels.)
function DeepONet(architecture_branch::Tuple, architecture_trunk::Tuple,
                    act_branch = identity, act_trunk = identity;
                    init_branch = Flux.glorot_uniform,
                    init_trunk = Flux.glorot_uniform,
                    bias_branch=true, bias_trunk=true)

    architecture_branch[end] == architecture_trunk[end] ||
        throw(DimensionMismatch("Branch and Trunk net must share the same amount of nodes in the last layer. Otherwise Σᵢ bᵢⱼ tᵢₖ won't work."))

    # To construct the subnets we use the helper function in subnets.jl
    # Initialize the branch net
    branch_net = construct_subnet(architecture_branch, act_branch;
                                    init=init_branch, bias=bias_branch)
    # Initialize the trunk net
    trunk_net = construct_subnet(architecture_trunk, act_trunk;
                                    init=init_trunk, bias=bias_trunk)

    return DeepONet(branch_net, trunk_net)
end
| 104 | + |
# Register DeepONet with Flux's functor machinery so its fields (the two
# subnets) are traversed for parameter collection and training.
Flux.@functor DeepONet
#= Forward pass of the operator.
`x` is the input function evaluated at m sensor locations (m × b for batches);
`y` holds the query points, shape (N × n): N variables, n evaluation points. =#
function (a::DeepONet)(x::AbstractArray, y::AbstractVecOrMat)
    # Both subnets transform along the first dimension, so the branch output
    # is transposed to line its latent axis up with the trunk output for the
    # contraction Σᵢ bᵢⱼ tᵢₖ.
    branch_out = a.branch_net(x)
    trunk_out = a.trunk_net(y)
    return branch_out' * trunk_out
end
| 121 | + |
# The trunk input is a fixed set of sensor locations; any trunk argument with
# more than two dimensions (i.e. batched) is rejected explicitly here.
function (a::DeepONet)(x::AbstractArray, y::AbstractArray)
    throw(ArgumentError("Sensor locations fed to trunk net can't be batched."))
end
| 125 | + |
# Compact human-readable summary of the architecture.
function Base.show(io::IO, l::DeepONet)
    print(io, "DeepONet with\nbranch net: (", l.branch_net, ")\n")
    print(io, "Trunk net: (", l.trunk_net, ")\n")
end
0 commit comments