
Commit ffa4a73

Implementation for DeepONet
1 parent 434770a

File tree: 7 files changed, +244 -0 lines

Project.toml

Lines changed: 1 addition & 0 deletions
@@ -10,6 +10,7 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
+MAT = "23992714-dd62-5051-b70f-ba57cb901cac"
 Tullio = "bc48ee85-29a4-5162-ae0b-a64e1601d4bc"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

src/DeepONet.jl

Lines changed: 132 additions & 0 deletions
@@ -0,0 +1,132 @@
"""
`DeepONet(architecture_branch::Tuple, architecture_trunk::Tuple,
          act_branch = identity, act_trunk = identity;
          init_branch = Flux.glorot_uniform,
          init_trunk = Flux.glorot_uniform,
          bias_branch=true, bias_trunk=true)`
`DeepONet(branch_net::Flux.Chain, trunk_net::Flux.Chain)`

Create an (unstacked) DeepONet architecture as proposed by Lu et al.
arXiv:1910.03193

The model works as follows:

x --- branch --
               |
               -⊠--u-
               |
y --- trunk ---

Here `x` represents the input function, discretely evaluated at its respective sensors,
so the input is of shape [m] for one instance or [m x b] for a training set.
`y` are the probing locations for the operator to be trained. It has shape [N x n] for
N different variables in the PDE (i.e. spatial and temporal coordinates), each with n distinct evaluation points.
`u` is the solution of the queried instance of the PDE, given by the specific choice of parameters.

Both inputs `x` and `y` are multiplied together via the dot product Σᵢ bᵢⱼ tᵢₖ.

You can set up this architecture in two ways:

1. By specifying the architecture and all its parameters as given above. This always creates
   `Dense` layers for the branch and trunk net and corresponds to the DeepONet proposed by Lu et al.

2. By passing two architectures in the form of two Chain structs directly. Do this if you want more
   flexibility, e.g. to use an RNN or CNN instead of simple `Dense` layers.

Strictly speaking, DeepONet does not require either the branch or the trunk net to be a simple
DNN. Usually, though, this is the case, which is why it's treated as the default here.

# Example

Consider a transient 1D advection problem ∂ₜu + u ⋅ ∇u = 0, with an IC u(x,0) = g(x).
We are given several (b = 200) instances of the IC, discretized at 50 points each, and want
to query the solution at 100 different locations and times in [0;1].

That makes the branch input of shape [50 x 200] and the trunk input of shape [2 x 100]. So the
input size is 50 for the branch net and 2 for the trunk net.

# Usage

```julia
julia> model = DeepONet((32,64,72), (24,64,72))
DeepONet with
branch net: (Chain(Dense(32, 64), Dense(64, 72)))
Trunk net: (Chain(Dense(24, 64), Dense(64, 72)))

julia> model = DeepONet((32,64,72), (24,64,72), σ, tanh; init_branch=Flux.glorot_normal, bias_trunk=false)
DeepONet with
branch net: (Chain(Dense(32, 64, σ), Dense(64, 72, σ)))
Trunk net: (Chain(Dense(24, 64, tanh; bias=false), Dense(64, 72, tanh; bias=false)))

julia> branch = Chain(Dense(2,128),Dense(128,64),Dense(64,72))
Chain(
  Dense(2, 128),   # 384 parameters
  Dense(128, 64),  # 8_256 parameters
  Dense(64, 72),   # 4_680 parameters
)                  # Total: 6 arrays, 13_320 parameters, 52.406 KiB.

julia> trunk = Chain(Dense(1,24),Dense(24,72))
Chain(
  Dense(1, 24),    # 48 parameters
  Dense(24, 72),   # 1_800 parameters
)                  # Total: 4 arrays, 1_848 parameters, 7.469 KiB.

julia> model = DeepONet(branch,trunk)
DeepONet with
branch net: (Chain(Dense(2, 128), Dense(128, 64), Dense(64, 72)))
Trunk net: (Chain(Dense(1, 24), Dense(24, 72)))
```
"""
struct DeepONet
    branch_net::Flux.Chain
    trunk_net::Flux.Chain
end

# Constructor: builds the branch and trunk subnets from the architecture
# tuples and initializes their weights and biases
function DeepONet(architecture_branch::Tuple, architecture_trunk::Tuple,
                  act_branch = identity, act_trunk = identity;
                  init_branch = Flux.glorot_uniform,
                  init_trunk = Flux.glorot_uniform,
                  bias_branch=true, bias_trunk=true)

    @assert architecture_branch[end] == architecture_trunk[end] "Branch and Trunk net must share the same amount of nodes in the last layer. Otherwise Σᵢ bᵢⱼ tᵢₖ won't work."

    # To construct the subnets we use the helper function in subnets.jl
    # Initialize the branch net
    branch_net = construct_subnet(architecture_branch, act_branch;
                                  init=init_branch, bias=bias_branch)
    # Initialize the trunk net
    trunk_net = construct_subnet(architecture_trunk, act_trunk;
                                 init=init_trunk, bias=bias_trunk)

    return DeepONet(branch_net, trunk_net)
end

Flux.@functor DeepONet

#= The actual layer that performs the forward pass:
x is the input function, evaluated at m locations (or m x b in case of batches)
y is the array of sensors, i.e. the variables of the output function,
with shape (N x n) - N different variables, each with n evaluation points =#
function (a::DeepONet)(x::AbstractArray, y::AbstractVecOrMat)
    # Assign the parameters
    branch, trunk = a.branch_net, a.trunk_net

    #= The dot product needs a dim to contract.
    However, the transformations by the NNs are always performed in the first dim,
    so we need to adjust (i.e. transpose) one of the inputs,
    which we do on the branch input here =#
    return branch(x)' * trunk(y)
end

# Sensors stay the same and shouldn't be batched
(a::DeepONet)(x::AbstractArray, y::AbstractArray) =
    throw(ArgumentError("Sensor locations fed to trunk net can't be batched."))

# Print nicely
function Base.show(io::IO, l::DeepONet)
    print(io, "DeepONet with\nbranch net: (", l.branch_net)
    print(io, ")\n")
    print(io, "Trunk net: (", l.trunk_net)
    print(io, ")\n")
end
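
As a quick illustration of the contraction above, here is a minimal sketch of what `branch(x)' * trunk(y)` computes, using the shapes from the docstring's advection example (all sizes and variable names below are chosen for illustration and are not part of the commit):

```julia
using Flux

# Hypothetical sizes: m = 50 sensors, b = 200 instances, N = 2 coordinates (x, t), n = 100 query points
branch = Chain(Dense(50, 64), Dense(64, 72))  # maps [50 × 200] -> [72 × 200]
trunk  = Chain(Dense(2, 64), Dense(64, 72))   # maps [2 × 100]  -> [72 × 100]

x = rand(Float32, 50, 200)  # input functions, evaluated at their sensors
y = rand(Float32, 2, 100)   # probing locations in space and time

# Σᵢ bᵢⱼ tᵢₖ: contract over the 72 latent features shared by both subnets
u = branch(x)' * trunk(y)
size(u)  # (200, 100): one row of solution values per input-function instance
```

This is also why the constructor asserts equal widths for the last branch and trunk layers: the contraction runs over that shared dimension.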

src/NeuralOperators.jl

Lines changed: 4 additions & 0 deletions
@@ -8,6 +8,10 @@ module NeuralOperators
 using Zygote
 using ChainRulesCore
 
+export DeepONet
+
 include("fourier.jl")
 include("model.jl")
+include("DeepONet.jl")
+include("subnets.jl")
 end
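
A side note on the include order here: `DeepONet.jl` calls `construct_subnet`, which is only defined in `subnets.jl`, included afterwards. That works because Julia resolves function calls when they are executed, not when a file is included. A minimal sketch of this call-time resolution (module and function names are illustrative, not from the repository):

```julia
# f references g before g is defined; both exist by the time f is called
module M
f() = g() + 1  # g is not defined yet at this point
g() = 41       # defined later in the same module
end

M.f()  # 42 - works, since g exists when f runs
```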

src/subnets.jl

Lines changed: 39 additions & 0 deletions
@@ -0,0 +1,39 @@
"""
Construct a Chain of `Dense` layers from a given tuple of integers.

Input:
A tuple (m,n,o,p) of integers, where each element gives the width of the corresponding layer.

Output:
A `Flux` Chain of length(architecture) - 1 `Dense` layers, with the individual widths given by the tuple elements.

# Example

```julia
julia> model = NeuralOperators.construct_subnet((2,128,64,32,1))
Chain(
  Dense(2, 128),   # 384 parameters
  Dense(128, 64),  # 8_256 parameters
  Dense(64, 32),   # 2_080 parameters
  Dense(32, 1),    # 33 parameters
)                  # Total: 8 arrays, 10_753 parameters, 42.504 KiB.

julia> model([2,1])
1-element Vector{Float32}:
 -0.7630446
```
"""
function construct_subnet(architecture::Tuple, σ = identity;
                          init=Flux.glorot_uniform, bias=true)
    # First, create an array that contains all Dense layers independently.
    # A given n-element architecture constructs n-1 layers.
    layers = Array{Flux.Dense}(undef, length(architecture)-1)
    @inbounds for i ∈ 2:length(architecture)
        layers[i-1] = Flux.Dense(architecture[i-1], architecture[i], σ;
                                 init=init, bias=bias)
    end

    # Concatenate the layers into a string, parse it as Flux Chain
    # constructor syntax, and evaluate it into a Chain
    return Meta.parse("Chain("*join(layers,",")*")") |> eval
end
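
For reference, the same `Chain` could be built without the `Meta.parse`/`eval` round-trip by splatting the layer array into the constructor. A minimal equivalent sketch (the name `construct_subnet_splat` is hypothetical, not part of this commit):

```julia
using Flux

# Equivalent sketch: Chain accepts layers directly, so no string parsing is needed
function construct_subnet_splat(architecture::Tuple, σ = identity;
                                init=Flux.glorot_uniform, bias=true)
    layers = [Flux.Dense(architecture[i-1], architecture[i], σ; init=init, bias=bias)
              for i ∈ 2:length(architecture)]
    return Flux.Chain(layers...)
end
```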

test/burgerset.mat

4.69 MB
Binary file not shown.

test/deeponet.jl

Lines changed: 67 additions & 0 deletions
@@ -0,0 +1,67 @@
using Test, Random, Flux, MAT

@testset "DeepONet" begin
    @testset "dimensions" begin
        # Test the proper construction
        # Branch net
        @test size(DeepONet((32,64,72), (24,48,72), σ, tanh).branch_net.layers[end].weight) == (72,64)
        @test size(DeepONet((32,64,72), (24,48,72), σ, tanh).branch_net.layers[end].bias) == (72,)
        # Trunk net
        @test size(DeepONet((32,64,72), (24,48,72), σ, tanh).trunk_net.layers[end].weight) == (72,48)
        @test size(DeepONet((32,64,72), (24,48,72), σ, tanh).trunk_net.layers[end].bias) == (72,)
    end

    # Accept only Int as architecture parameters
    @test_throws MethodError DeepONet((32.5,64,72), (24,48,72), σ, tanh)
    @test_throws MethodError DeepONet((32,64,72), (24.1,48,72))
end

# Just the first 16 datapoints from the Burgers' equation dataset
a = [0.83541104, 0.83479851, 0.83404712, 0.83315711, 0.83212979, 0.83096755, 0.82967374, 0.82825263, 0.82670928, 0.82504949, 0.82327962, 0.82140651, 0.81943734, 0.81737952, 0.8152405, 0.81302771]
sensors = collect(range(0, 1, length=16))'

model = DeepONet((16, 22, 30), (1, 16, 24, 30), σ, tanh; init_branch=Flux.glorot_normal, bias_trunk=false)

model(a, sensors)

# Forward pass
@test size(model(a, sensors)) == (1, 16)

mgrad = Flux.Zygote.gradient((x, p) -> sum(model(x, p)), a, sensors)

# Gradients
@test !iszero(Flux.Zygote.gradient((x, p) -> sum(model(x, p)), a, sensors)[1])
@test !iszero(Flux.Zygote.gradient((x, p) -> sum(model(x, p)), a, sensors)[2])

# Training
# Dataset containing the first 300 initial conditions from the Burgers' equation
# dataset used by Li et al. for the Fourier neural operator. The data for the
# initial conditions is sampled at an interval of 8 points, so the original data
# has 2048 ICs at 8192 points, while here we have 300 ICs at 1024 points.
vars = matread("burgerset.mat")

xtrain = vars["a"][1:280, :]'
xval = vars["a"][end-19:end, :]'

ytrain = vars["u"][1:280, :]
yval = vars["u"][end-19:end, :]

grid = collect(range(0, 1, length=1024))'
model = DeepONet((1024,1024,1024), (1,1024,1024), gelu, gelu)

learning_rate = 0.001
opt = ADAM(learning_rate)

parameters = params(model)

loss(xtrain, ytrain, sensor) = Flux.Losses.mse(model(xtrain, sensor), ytrain)

evalcb() = @show(loss(xval, yval, grid))

Flux.@epochs 400 Flux.train!(loss, parameters, [(xtrain, ytrain, grid)], opt, cb = evalcb)

ỹ = model(xval, grid)

diffvec = vec(abs.(yval .- ỹ))
mean_diff = sum(diffvec)/length(diffvec)
@test mean_diff < 0.4

test/runtests.jl

Lines changed: 1 addition & 0 deletions
@@ -5,6 +5,7 @@ using Flux
 @testset "NeuralOperators.jl" begin
     include("fourier.jl")
     include("model.jl")
+    include("deeponet.jl")
 end
 
 #=
