Skip to content
This repository was archived by the owner on Sep 28, 2024. It is now read-only.

Commit 67d9007

Browse files
authored
Merge pull request #5 from ayushinav/deeponet
Deeponet
2 parents 8d34b77 + ebf9e87 commit 67d9007

File tree

4 files changed

+172
-0
lines changed

4 files changed

+172
-0
lines changed

src/LuxNeuralOperators.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,11 @@ include("functional.jl")
2525
include("layers.jl")
2626

2727
include("fno.jl")
28+
include("deeponet.jl")
2829

2930
export FourierTransform
3031
export SpectralConv, OperatorConv, SpectralKernel, OperatorKernel
3132
export FourierNeuralOperator
33+
export DeepONet
3234

3335
end

src/deeponet.jl

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
"""
    DeepONet(; branch = (64, 32, 32, 16), trunk = (1, 8, 8, 16),
        branch_activation = identity, trunk_activation = identity)

Constructs a DeepONet composed of Dense layers. Make sure the last node of `branch` and
`trunk` are same.

## Keyword arguments:

  - `branch`: Tuple of integers containing the number of nodes in each layer for branch net
  - `trunk`: Tuple of integers containing the number of nodes in each layer for trunk net
  - `branch_activation`: activation function for branch net
  - `trunk_activation`: activation function for trunk net

## References

[1] Lu Lu, Pengzhan Jin, George Em Karniadakis, "DeepONet: Learning nonlinear operators for
identifying differential equations based on the universal approximation theorem of
operators", doi: https://arxiv.org/abs/1910.03193

## Example

```jldoctest
deeponet = DeepONet(; branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16))

# output

Branch net :
(
    Chain(
        layer_1 = Dense(64 => 32),      # 2_080 parameters
        layer_2 = Dense(32 => 32),      # 1_056 parameters
        layer_3 = Dense(32 => 16),      # 528 parameters
    ),
)

Trunk net :
(
    Chain(
        layer_1 = Dense(1 => 8),        # 16 parameters
        layer_2 = Dense(8 => 8),        # 72 parameters
        layer_3 = Dense(8 => 16),       # 144 parameters
    ),
)
```
"""
function DeepONet(; branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16),
        branch_activation=identity, trunk_activation=identity)
    # The combination step Σᵢ bᵢⱼ tᵢₖ is an inner product over the last layer, so both
    # nets must end with the same number of nodes.
    @argcheck branch[end]==trunk[end] "Branch and Trunk net must share the same amount of \
                                       nodes in the last layer. Otherwise Σᵢ bᵢⱼ tᵢₖ won't \
                                       work."

    # Build an MLP from a tuple of layer widths: one Dense per consecutive width pair,
    # all sharing the same activation.
    make_mlp(widths, act) = Chain([Dense(w_in => w_out, act)
                                   for (w_in, w_out) in
                                       zip(widths[1:(end - 1)], widths[2:end])]...)

    return DeepONet(make_mlp(branch, branch_activation), make_mlp(trunk, trunk_activation))
end
63+
64+
"""
    DeepONet(branch, trunk)

Constructs a DeepONet from `branch` and `trunk` architectures. Make sure that the outputs
of both nets share the same first dimension.

## Arguments

  - `branch`: `Lux` network to be used as branch net.
  - `trunk`: `Lux` network to be used as trunk net.

## References

[1] Lu Lu, Pengzhan Jin, George Em Karniadakis, "DeepONet: Learning nonlinear operators for
identifying differential equations based on the universal approximation theorem of
operators", doi: https://arxiv.org/abs/1910.03193

## Example

```jldoctest
branch_net = Chain(Dense(64 => 32), Dense(32 => 32), Dense(32 => 16));
trunk_net = Chain(Dense(1 => 8), Dense(8 => 8), Dense(8 => 16));
don_ = DeepONet(branch_net, trunk_net)

# output

Branch net :
(
    Chain(
        layer_1 = Dense(64 => 32),      # 2_080 parameters
        layer_2 = Dense(32 => 32),      # 1_056 parameters
        layer_3 = Dense(32 => 16),      # 528 parameters
    ),
)

Trunk net :
(
    Chain(
        layer_1 = Dense(1 => 8),        # 16 parameters
        layer_2 = Dense(8 => 8),        # 72 parameters
        layer_3 = Dense(8 => 16),       # 144 parameters
    ),
)
```
"""
function DeepONet(branch::L1, trunk::L2) where {L1, L2}
    # The forward pass takes a tuple `(u, y)`: `u` is the branch input (sensor values),
    # `y` the trunk input (evaluation locations).
    return @compact(; branch, trunk, dispatch=:DeepONet) do (u, y) # ::AbstractArray{<:Real, M} where {M}
        t = trunk(y) # p x N x nb
        b = branch(u) # p x nb

        # checks for last dimension size
        @argcheck size(t, 1)==size(b, 1) "Branch and Trunk net must share the same amount \
                                          of nodes in the last layer. Otherwise Σᵢ bᵢⱼ tᵢₖ \
                                          won't work."

        tᵀ = permutedims(t, (2, 1, 3)) # N x p x nb
        # `reshape` reinterprets `b` as p x 1 x nb without copying (column-major order is
        # unchanged), unlike the permutedims-based equivalent which allocates a copy.
        b_ = reshape(b, size(b, 1), 1, size(b, 2)) # p x 1 x nb
        G = batched_mul(tᵀ, b_) # N x 1 x nb
        @return dropdims(G; dims=2)
    end
end
125+
126+
# Compact display: print the branch and trunk sub-networks one after the other,
# separated by a blank line.
function Base.show(io::IO, layer::Lux.CompactLuxLayer{:DeepONet})
    Lux._print_wrapper_model(io, "Branch net :\n", layer.layers.branch)
    print(io, "\n \n")
    Lux._print_wrapper_model(io, "Trunk net :\n", layer.layers.trunk)
end
131+
132+
# Rich (text/plain) display simply reuses the compact `show` method above.
# Qualified as `Lux.CompactLuxLayer` for consistency with the 2-arg `show` method.
function Base.show(io::IO, ::MIME"text/plain", x::Lux.CompactLuxLayer{:DeepONet})
    show(io, x)
end

src/functional.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
12
@inline function operator_conv(x, tform::AbstractTransform, weights)
23
x_t = transform(tform, x)
34
x_tr = truncate_modes(tform, x_t)

test/deeponet_tests.jl

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
@testitem "DeepONet" setup=[SharedTestSetup] begin
    @testset "BACKEND: $(mode)" for (mode, aType, dev, ongpu) in MODES
        rng_ = get_stable_rng()

        # Draw the inputs from the stable RNG too, so a test failure is reproducible
        # instead of depending on the global RNG state.
        u = rand(rng_, 64, 5) |> aType # sensor_points x nb
        y = rand(rng_, 1, 10, 5) |> aType # ndims x N x nb
        out_size = (10, 5)

        # Keyword-tuple constructor.
        don_ = DeepONet(; branch=(64, 32, 32, 16), trunk=(1, 8, 8, 16))

        ps, st = Lux.setup(rng_, don_) |> dev

        @inferred don_((u, y), ps, st)
        @jet don_((u, y), ps, st)

        pred = first(don_((u, y), ps, st))
        @test size(pred) == out_size

        # Explicit branch/trunk network constructor.
        don_ = DeepONet(Chain(Dense(64 => 32), Dense(32 => 32), Dense(32 => 16)),
            Chain(Dense(1 => 8), Dense(8 => 8), Dense(8 => 16)))

        ps, st = Lux.setup(rng_, don_) |> dev

        @inferred don_((u, y), ps, st)
        @jet don_((u, y), ps, st)

        pred = first(don_((u, y), ps, st))
        @test size(pred) == out_size

        # Mismatched last-layer sizes (20 vs 16) must raise at call time.
        don_ = DeepONet(Chain(Dense(64 => 32), Dense(32 => 32), Dense(32 => 20)),
            Chain(Dense(1 => 8), Dense(8 => 8), Dense(8 => 16)))
        ps, st = Lux.setup(rng_, don_) |> dev
        @test_throws ArgumentError don_((u, y), ps, st)
    end
end

0 commit comments

Comments
 (0)